Merge branch 'main' into feature/get-trace-id-from-req-headers
This commit is contained in:
commit
14e8ca5236
|
@ -2706,9 +2706,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "opentelemetry"
|
name = "opentelemetry"
|
||||||
version = "0.23.0"
|
version = "0.24.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1b69a91d4893e713e06f724597ad630f1fa76057a5e1026c0ca67054a9032a76"
|
checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-sink",
|
"futures-sink",
|
||||||
|
@ -2819,19 +2819,17 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "opentelemetry_sdk"
|
name = "opentelemetry_sdk"
|
||||||
version = "0.23.0"
|
version = "0.24.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ae312d58eaa90a82d2e627fd86e075cf5230b3f11794e2ed74199ebbe572d4fd"
|
checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"futures-channel",
|
"futures-channel",
|
||||||
"futures-executor",
|
"futures-executor",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"glob",
|
"glob",
|
||||||
"lazy_static",
|
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"opentelemetry 0.23.0",
|
"opentelemetry 0.24.0",
|
||||||
"ordered-float 4.3.0",
|
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
"rand",
|
"rand",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
|
@ -4185,16 +4183,17 @@ dependencies = [
|
||||||
"cmake",
|
"cmake",
|
||||||
"cxx",
|
"cxx",
|
||||||
"cxx-build",
|
"cxx-build",
|
||||||
|
"hashbrown 0.14.5",
|
||||||
|
"hf-hub",
|
||||||
"log",
|
"log",
|
||||||
"parking_lot",
|
|
||||||
"pkg-config",
|
"pkg-config",
|
||||||
"text-generation-router",
|
"text-generation-router",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokenizers 0.19.1",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-opentelemetry 0.24.0",
|
"tracing-opentelemetry 0.25.0",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -4212,7 +4211,7 @@ dependencies = [
|
||||||
"tabled",
|
"tabled",
|
||||||
"text-generation-client",
|
"text-generation-client",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokenizers 0.20.0",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
|
@ -4292,7 +4291,7 @@ dependencies = [
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"sysinfo",
|
"sysinfo",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokenizers 0.20.0",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"tower-http",
|
"tower-http",
|
||||||
|
@ -4341,7 +4340,7 @@ dependencies = [
|
||||||
"slotmap",
|
"slotmap",
|
||||||
"text-generation-router",
|
"text-generation-router",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokenizers 0.20.0",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"tonic 0.10.2",
|
"tonic 0.10.2",
|
||||||
|
@ -4392,7 +4391,7 @@ dependencies = [
|
||||||
"slotmap",
|
"slotmap",
|
||||||
"text-generation-router",
|
"text-generation-router",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokenizers 0.20.0",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"tonic 0.10.2",
|
"tonic 0.10.2",
|
||||||
|
@ -4514,39 +4513,6 @@ version = "0.1.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "tokenizers"
|
|
||||||
version = "0.19.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "e500fad1dd3af3d626327e6a3fe5050e664a6eaa4708b8ca92f1794aaf73e6fd"
|
|
||||||
dependencies = [
|
|
||||||
"aho-corasick",
|
|
||||||
"derive_builder",
|
|
||||||
"esaxx-rs",
|
|
||||||
"getrandom",
|
|
||||||
"hf-hub",
|
|
||||||
"indicatif",
|
|
||||||
"itertools 0.12.1",
|
|
||||||
"lazy_static",
|
|
||||||
"log",
|
|
||||||
"macro_rules_attribute",
|
|
||||||
"monostate",
|
|
||||||
"onig",
|
|
||||||
"paste",
|
|
||||||
"rand",
|
|
||||||
"rayon",
|
|
||||||
"rayon-cond",
|
|
||||||
"regex",
|
|
||||||
"regex-syntax 0.8.5",
|
|
||||||
"serde",
|
|
||||||
"serde_json",
|
|
||||||
"spm_precompiled",
|
|
||||||
"thiserror",
|
|
||||||
"unicode-normalization-alignments",
|
|
||||||
"unicode-segmentation",
|
|
||||||
"unicode_categories",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tokenizers"
|
name = "tokenizers"
|
||||||
version = "0.20.0"
|
version = "0.20.0"
|
||||||
|
@ -4933,14 +4899,14 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tracing-opentelemetry"
|
name = "tracing-opentelemetry"
|
||||||
version = "0.24.0"
|
version = "0.25.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f68803492bf28ab40aeccaecc7021096bd256baf7ca77c3d425d89b35a7be4e4"
|
checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"js-sys",
|
"js-sys",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"opentelemetry 0.23.0",
|
"opentelemetry 0.24.0",
|
||||||
"opentelemetry_sdk 0.23.0",
|
"opentelemetry_sdk 0.24.1",
|
||||||
"smallvec",
|
"smallvec",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-core",
|
"tracing-core",
|
||||||
|
|
|
@ -1,23 +0,0 @@
|
||||||
# All the tooling for CUDA
|
|
||||||
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 AS cuda-builder
|
|
||||||
|
|
||||||
WORKDIR /usr/src/tgi/backends/trtllm
|
|
||||||
RUN apt update && apt install -y cmake git git-lfs gcc g++ ninja-build libopenmpi-dev python3-dev python3-pip wget
|
|
||||||
|
|
||||||
COPY . /usr/src/tgi
|
|
||||||
RUN chmod +x scripts/install_tensorrt.sh && scripts/install_tensorrt.sh
|
|
||||||
RUN cmake -G Ninja -B build -DTRT_LIB_DIR=/usr/local/tensorrt/lib -DTRT_INCLUDE_DIR=/usr/local/tensorrt/include .
|
|
||||||
RUN cmake --build build --parallel -t tgi_trtllm_backend_impl
|
|
||||||
|
|
||||||
# All the tooling for Rust
|
|
||||||
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
|
|
||||||
WORKDIR /usr/src
|
|
||||||
|
|
||||||
# Include CUDA related libraries and tools to the Rust based image
|
|
||||||
COPY --from=cuda-builder /usr/local/cuda /usr/local/cuda
|
|
||||||
COPY --from=cuda-builder /usr/local/tensorrt /usr/local/tensorrt
|
|
||||||
COPY --from=cuda-builder /usr/src/tgi/backends/trtllm/build /usr/local/tgi/trtllm/build
|
|
||||||
ENV PATH=/usr/local/cuda/bin:$PATH
|
|
||||||
ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH
|
|
||||||
|
|
||||||
RUN apt update && apt install -y cmake git gcc g++ ninja-build libopenmpi3
|
|
|
@ -327,7 +327,8 @@ ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
|
||||||
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
|
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
|
||||||
ENV VLLM_MOE_PADDING=0
|
ENV VLLM_MOE_PADDING=0
|
||||||
ENV ATTENTION=paged
|
ENV ATTENTION=paged
|
||||||
ENV USE_PREFIX_CACHING=0
|
ENV PREFIX_CACHING=0
|
||||||
|
ENV PREFILL_CHUNKING=0
|
||||||
ENV ROCM_USE_SKINNY_GEMM=1
|
ENV ROCM_USE_SKINNY_GEMM=1
|
||||||
|
|
||||||
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
|
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
|
||||||
|
|
|
@ -218,7 +218,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/lo
|
||||||
|
|
||||||
FROM ${PLATFORM} AS final
|
FROM ${PLATFORM} AS final
|
||||||
ENV ATTENTION=paged
|
ENV ATTENTION=paged
|
||||||
ENV USE_PREFIX_CACHING=0
|
ENV PREFIX_CACHING=0
|
||||||
|
ENV PREFILL_CHUNKING=0
|
||||||
ENV CUDA_GRAPHS=0
|
ENV CUDA_GRAPHS=0
|
||||||
ENTRYPOINT ["text-generation-launcher"]
|
ENTRYPOINT ["text-generation-launcher"]
|
||||||
CMD ["--json-output"]
|
CMD ["--json-output"]
|
||||||
|
|
|
@ -10,7 +10,7 @@ COPY . .
|
||||||
RUN cargo chef prepare --recipe-path recipe.json
|
RUN cargo chef prepare --recipe-path recipe.json
|
||||||
|
|
||||||
# CUDA dependent dependencies resolver stage
|
# CUDA dependent dependencies resolver stage
|
||||||
FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu22.04 AS cuda-builder
|
FROM nvidia/cuda:12.6.1-cudnn-devel-ubuntu22.04 AS cuda-builder
|
||||||
|
|
||||||
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||||
|
@ -26,6 +26,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||||
ninja-build \
|
ninja-build \
|
||||||
pkg-config \
|
pkg-config \
|
||||||
python3 \
|
python3 \
|
||||||
|
python3-dev \
|
||||||
python3-setuptools \
|
python3-setuptools \
|
||||||
tar \
|
tar \
|
||||||
wget
|
wget
|
||||||
|
@ -42,7 +43,7 @@ RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILE
|
||||||
mkdir /usr/src/mpi && \
|
mkdir /usr/src/mpi && \
|
||||||
tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
|
tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
|
||||||
cd /usr/src/mpi && \
|
cd /usr/src/mpi && \
|
||||||
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda && \
|
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
|
||||||
make -j all && \
|
make -j all && \
|
||||||
make install && \
|
make install && \
|
||||||
rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
|
rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
|
||||||
|
@ -82,10 +83,16 @@ RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$
|
||||||
cd backends/trtllm && \
|
cd backends/trtllm && \
|
||||||
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release
|
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release
|
||||||
|
|
||||||
FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime
|
FROM nvidia/cuda:12.6.1-cudnn-runtime-ubuntu22.04 AS runtime
|
||||||
|
RUN apt update && apt install -y python3-minimal python3-dev python3-pip && \
|
||||||
|
rm -rf /var/lib/{apt,dpkg,cache,log}/ && \
|
||||||
|
python3 -m pip install transformers tokenizers
|
||||||
|
|
||||||
WORKDIR /usr/local/tgi/bin
|
WORKDIR /usr/local/tgi/bin
|
||||||
|
|
||||||
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
|
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
|
||||||
|
ENV TOKENIZERS_PARALLELISM=false
|
||||||
|
ENV OMPI_MCA_plm_rsh_agent=""
|
||||||
|
|
||||||
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
|
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
|
||||||
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
|
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
|
|
@ -98,7 +98,7 @@ curl 127.0.0.1:8080/generate_stream \
|
||||||
You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain Open AI Chat Completion API compatible responses.
|
You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain Open AI Chat Completion API compatible responses.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
curl localhost:3000/v1/chat/completions \
|
curl localhost:8080/v1/chat/completions \
|
||||||
-X POST \
|
-X POST \
|
||||||
-d '{
|
-d '{
|
||||||
"model": "tgi",
|
"model": "tgi",
|
||||||
|
|
|
@ -158,7 +158,8 @@ impl Client {
|
||||||
// Blocks and slots will be set on the server side if we use paged attention
|
// Blocks and slots will be set on the server side if we use paged attention
|
||||||
blocks: vec![],
|
blocks: vec![],
|
||||||
slots: vec![],
|
slots: vec![],
|
||||||
prefix_len: 0,
|
cache_len: 0,
|
||||||
|
chunk_len: None,
|
||||||
// Set sampling parameters to also take these ops into account in the max memory
|
// Set sampling parameters to also take these ops into account in the max memory
|
||||||
parameters: Some(NextTokenChooserParameters {
|
parameters: Some(NextTokenChooserParameters {
|
||||||
temperature: 0.9,
|
temperature: 0.9,
|
||||||
|
@ -217,8 +218,13 @@ impl Client {
|
||||||
pub async fn prefill(
|
pub async fn prefill(
|
||||||
&mut self,
|
&mut self,
|
||||||
batch: Batch,
|
batch: Batch,
|
||||||
|
cached_batch: Option<CachedBatch>,
|
||||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||||
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
|
let request = tonic::Request::new(PrefillRequest {
|
||||||
|
batch: Some(batch),
|
||||||
|
cached_batch,
|
||||||
|
})
|
||||||
|
.inject_context();
|
||||||
let response = self.stub.prefill(request).await?.into_inner();
|
let response = self.stub.prefill(request).await?.into_inner();
|
||||||
Ok((
|
Ok((
|
||||||
response.generations,
|
response.generations,
|
||||||
|
|
|
@ -134,11 +134,12 @@ impl ShardedClient {
|
||||||
pub async fn prefill(
|
pub async fn prefill(
|
||||||
&mut self,
|
&mut self,
|
||||||
batch: Batch,
|
batch: Batch,
|
||||||
|
cached_batch: Option<CachedBatch>,
|
||||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||||
let futures: Vec<_> = self
|
let futures: Vec<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| Box::pin(client.prefill(batch.clone())))
|
.map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
|
||||||
.collect();
|
.collect();
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
||||||
|
@ -245,7 +246,8 @@ impl Health for ShardedClient {
|
||||||
// Block 0 is reserved for health checks
|
// Block 0 is reserved for health checks
|
||||||
blocks: vec![0],
|
blocks: vec![0],
|
||||||
slots: (0..16).collect(),
|
slots: (0..16).collect(),
|
||||||
prefix_len: 0,
|
cache_len: 0,
|
||||||
|
chunk_len: None,
|
||||||
adapter_id: None,
|
adapter_id: None,
|
||||||
};
|
};
|
||||||
let batch = Batch {
|
let batch = Batch {
|
||||||
|
@ -255,7 +257,7 @@ impl Health for ShardedClient {
|
||||||
max_tokens: 2,
|
max_tokens: 2,
|
||||||
max_blocks: 1,
|
max_blocks: 1,
|
||||||
};
|
};
|
||||||
self.clone().prefill(batch).await?;
|
self.clone().prefill(batch, None).await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,17 @@
|
||||||
cmake_minimum_required(VERSION 3.20)
|
cmake_minimum_required(VERSION 3.20)
|
||||||
|
|
||||||
|
if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND CMAKE_BUILD_TYPE STREQUAL "Debug")
|
||||||
|
find_program(CCACHE_EXECUTABLE "ccache")
|
||||||
|
if (CCACHE_EXECUTABLE)
|
||||||
|
message(STATUS "Using ccache")
|
||||||
|
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_EXECUTABLE}" CACHE PATH "Path to ccache" FORCE)
|
||||||
|
endif ()
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
|
||||||
|
cmake_policy(SET CMP0135 NEW)
|
||||||
|
endif ()
|
||||||
|
|
||||||
project(tgi-trtllm-backend VERSION 1.0.0)
|
project(tgi-trtllm-backend VERSION 1.0.0)
|
||||||
set(CMAKE_CXX_STANDARD 20)
|
set(CMAKE_CXX_STANDARD 20)
|
||||||
|
|
||||||
|
@ -14,7 +26,7 @@ set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include"
|
||||||
set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
|
set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
|
||||||
|
|
||||||
# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
|
# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
|
||||||
find_package(CUDAToolkit 12.5 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
|
find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
|
||||||
|
|
||||||
#### External dependencies ####
|
#### External dependencies ####
|
||||||
include(cmake/fmt.cmake)
|
include(cmake/fmt.cmake)
|
||||||
|
|
|
@ -10,16 +10,17 @@ async-trait = "0.1"
|
||||||
async-stream = "0.3"
|
async-stream = "0.3"
|
||||||
clap = { version = "4.5", features = ["derive"] }
|
clap = { version = "4.5", features = ["derive"] }
|
||||||
cxx = "1.0"
|
cxx = "1.0"
|
||||||
|
hashbrown = "0.14"
|
||||||
|
hf-hub = { workspace = true }
|
||||||
log = { version = "0.4", features = [] }
|
log = { version = "0.4", features = [] }
|
||||||
text-generation-router = { path = "../../router" }
|
text-generation-router = { path = "../../router" }
|
||||||
tokenizers = { version = "0.19", features = ["hf-hub"] }
|
tokenizers = { workspace = true }
|
||||||
tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||||
tokio-stream = "0.1.15"
|
tokio-stream = "0.1.15"
|
||||||
thiserror = "1.0.62"
|
thiserror = "1.0.63"
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-opentelemetry = "0.24"
|
tracing-opentelemetry = "0.25"
|
||||||
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
|
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
|
||||||
parking_lot = "0.12"
|
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
cmake = "0.1"
|
cmake = "0.1"
|
||||||
|
|
|
@ -6,7 +6,7 @@ use std::path::{absolute, PathBuf};
|
||||||
|
|
||||||
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
|
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
|
||||||
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
|
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
|
||||||
const CUDA_REQUIRED_VERSION: &str = "12.5";
|
const CUDA_REQUIRED_VERSION: &str = "12.6";
|
||||||
const MPI_REQUIRED_VERSION: &str = "4.1";
|
const MPI_REQUIRED_VERSION: &str = "4.1";
|
||||||
const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
|
const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
|
||||||
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
|
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
|
||||||
|
@ -36,7 +36,7 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
|
||||||
// Build the backend implementation through CMake
|
// Build the backend implementation through CMake
|
||||||
let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
|
let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
|
||||||
let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt");
|
let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt");
|
||||||
let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("90-real"); // Hopper by default
|
let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("75-real;80-real;86-real;89-real;90-real");
|
||||||
|
|
||||||
let mut install_path = PathBuf::from(install_path);
|
let mut install_path = PathBuf::from(install_path);
|
||||||
if !install_path.is_absolute() {
|
if !install_path.is_absolute() {
|
||||||
|
@ -81,7 +81,12 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
|
||||||
(PathBuf::from(install_path), deps_folder)
|
(PathBuf::from(install_path), deps_folder)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_ffi_layer(deps_folder: &PathBuf) {
|
fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
|
||||||
|
let ndebug = match is_debug {
|
||||||
|
true => "1",
|
||||||
|
false => "0",
|
||||||
|
};
|
||||||
|
|
||||||
CFG.include_prefix = "backends/trtllm";
|
CFG.include_prefix = "backends/trtllm";
|
||||||
cxx_build::bridge("src/lib.rs")
|
cxx_build::bridge("src/lib.rs")
|
||||||
.static_flag(true)
|
.static_flag(true)
|
||||||
|
@ -93,9 +98,14 @@ fn build_ffi_layer(deps_folder: &PathBuf) {
|
||||||
.include("/usr/local/tensorrt/include")
|
.include("/usr/local/tensorrt/include")
|
||||||
.file("src/ffi.cpp")
|
.file("src/ffi.cpp")
|
||||||
.std("c++20")
|
.std("c++20")
|
||||||
|
.define("NDEBUG", ndebug)
|
||||||
.compile("tgi_trtllm_backend");
|
.compile("tgi_trtllm_backend");
|
||||||
|
|
||||||
println!("cargo:rerun-if-changed=CMakeLists.txt");
|
println!("cargo:rerun-if-changed=CMakeLists.txt");
|
||||||
|
println!("cargo:rerun-if-changed=cmake/trtllm.cmake");
|
||||||
|
println!("cargo:rerun-if-changed=cmake/json.cmake");
|
||||||
|
println!("cargo:rerun-if-changed=cmake/fmt.cmake");
|
||||||
|
println!("cargo:rerun-if-changed=cmake/spdlog.cmake");
|
||||||
println!("cargo:rerun-if-changed=include/backend.h");
|
println!("cargo:rerun-if-changed=include/backend.h");
|
||||||
println!("cargo:rerun-if-changed=lib/backend.cpp");
|
println!("cargo:rerun-if-changed=lib/backend.cpp");
|
||||||
println!("cargo:rerun-if-changed=include/ffi.h");
|
println!("cargo:rerun-if-changed=include/ffi.h");
|
||||||
|
@ -115,7 +125,7 @@ fn main() {
|
||||||
let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir);
|
let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir);
|
||||||
|
|
||||||
// Build the FFI layer calling the backend above
|
// Build the FFI layer calling the backend above
|
||||||
build_ffi_layer(&deps_folder);
|
build_ffi_layer(&deps_folder, is_debug);
|
||||||
|
|
||||||
// Emit linkage search path
|
// Emit linkage search path
|
||||||
probe!("ompi", MPI_REQUIRED_VERSION);
|
probe!("ompi", MPI_REQUIRED_VERSION);
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
fmt
|
fmt
|
||||||
GIT_REPOSITORY https://github.com/fmtlib/fmt
|
DOWNLOAD_EXTRACT_TIMESTAMP
|
||||||
GIT_TAG 11.0.1
|
URL https://github.com/fmtlib/fmt/archive/refs/tags/11.0.2.tar.gz
|
||||||
)
|
)
|
||||||
FetchContent_MakeAvailable(fmt)
|
FetchContent_MakeAvailable(fmt)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
fetchcontent_declare(
|
fetchcontent_declare(
|
||||||
json
|
json
|
||||||
|
DOWNLOAD_EXTRACT_TIMESTAMP
|
||||||
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz
|
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz
|
||||||
)
|
)
|
||||||
fetchcontent_makeavailable(json)
|
fetchcontent_makeavailable(json)
|
||||||
|
|
|
@ -11,7 +11,7 @@ endif ()
|
||||||
|
|
||||||
fetchcontent_declare(
|
fetchcontent_declare(
|
||||||
spdlog
|
spdlog
|
||||||
GIT_REPOSITORY https://github.com/gabime/spdlog.git
|
DOWNLOAD_EXTRACT_TIMESTAMP
|
||||||
GIT_TAG v1.14.1
|
URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz
|
||||||
)
|
)
|
||||||
fetchcontent_makeavailable(spdlog)
|
fetchcontent_makeavailable(spdlog)
|
||||||
|
|
|
@ -23,8 +23,9 @@ endif ()
|
||||||
fetchcontent_declare(
|
fetchcontent_declare(
|
||||||
trtllm
|
trtllm
|
||||||
GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
|
GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
|
||||||
GIT_TAG a681853d3803ee5893307e812530b5e7004bb6e1
|
GIT_TAG 201135e58aa525af7e523d091d4c9584229524bc
|
||||||
GIT_SHALLOW FALSE
|
GIT_SHALLOW FALSE
|
||||||
|
DOWNLOAD_EXTRACT_TIMESTAMP
|
||||||
)
|
)
|
||||||
fetchcontent_makeavailable(trtllm)
|
fetchcontent_makeavailable(trtllm)
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
#ifndef TGI_TRTLLM_BACKEND_H
|
#ifndef TGI_TRTLLM_BACKEND_H
|
||||||
#define TGI_TRTLLM_BACKEND_H
|
#define TGI_TRTLLM_BACKEND_H
|
||||||
|
|
||||||
|
#include <array>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
#include <span>
|
#include <span>
|
||||||
|
@ -19,16 +20,33 @@
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
namespace tle = tensorrt_llm::executor;
|
namespace tle = tensorrt_llm::executor;
|
||||||
|
|
||||||
|
|
||||||
|
#define CAST_SIZETYPE(x) static_cast<tle::SizeType32>(x)
|
||||||
|
|
||||||
namespace huggingface::tgi::backends {
|
namespace huggingface::tgi::backends {
|
||||||
using RequestId = tle::IdType;
|
using RequestId = tle::IdType;
|
||||||
using TokenId = tle::TokenIdType;
|
using TokenId = tle::TokenIdType;
|
||||||
|
|
||||||
|
const static auto OUTPUT_CONFIG = tle::OutputConfig(true, false, false, true, false);
|
||||||
|
constexpr auto FMT_NOT_ENOUGH_GPUS = FMT_STRING(
|
||||||
|
"Not enough GPUs to allocate requested model (detected: {:d}, required: {:d})");
|
||||||
|
constexpr auto FMT_EXECUTOR_STATS = FMT_STRING(
|
||||||
|
"Submitting inference [{}] to the executor ({:d} already in-flight)");
|
||||||
|
constexpr auto FMT_SAMPLING_CONFIG = FMT_STRING(
|
||||||
|
"Sampling: topK={:d}, topP={:.1f}, temperature={:.1f}, repetition_penalty={:.1f}, frequency_penalty={:.1f}, seed={:d}");
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initialize all the components required by TRTLLM.
|
* Initialize all the components required by TRTLLM.
|
||||||
* It is required to call this function before attempting to load any engine
|
* It is required to call this function before attempting to load any engine
|
||||||
*/
|
*/
|
||||||
void InitializeBackend();
|
void InitializeBackend();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialize logging mechanism
|
||||||
|
*/
|
||||||
|
void InitializeLogging();
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @param config TensorRT-LLM configuration object
|
* @param config TensorRT-LLM configuration object
|
||||||
|
@ -37,6 +55,14 @@ namespace huggingface::tgi::backends {
|
||||||
*/
|
*/
|
||||||
tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
|
tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param worldSize
|
||||||
|
* @param workerPath
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
tle::ParallelConfig GetParallelConfig(size_t worldSize, std::string workerPath) noexcept;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the sampling configuration from the parameters provided by TGI
|
* Get the sampling configuration from the parameters provided by TGI
|
||||||
* @param topK
|
* @param topK
|
||||||
|
@ -54,7 +80,15 @@ namespace huggingface::tgi::backends {
|
||||||
float_t repetition_penalty,
|
float_t repetition_penalty,
|
||||||
float_t frequency_penalty,
|
float_t frequency_penalty,
|
||||||
uint64_t seed
|
uint64_t seed
|
||||||
);
|
) noexcept;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Attempt to retrieve the
|
||||||
|
* @param generationConfigPath
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
std::optional<std::list<std::vector<TokenId>>>
|
||||||
|
GetStopWordsFromConfig(const std::filesystem::path &generationConfigPath) noexcept;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -64,18 +98,16 @@ namespace huggingface::tgi::backends {
|
||||||
const json config;
|
const json config;
|
||||||
tle::Executor executor;
|
tle::Executor executor;
|
||||||
|
|
||||||
|
/** Frequently accessed variables cached here **/
|
||||||
|
uint32_t maxNumTokens;
|
||||||
|
std::list<std::vector<TokenId>> stopWords;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit TensorRtLlmBackend(
|
explicit TensorRtLlmBackend(
|
||||||
const std::filesystem::path &engineFolder,
|
const std::filesystem::path &engineFolder,
|
||||||
const std::filesystem::path &executorWorker
|
const std::filesystem::path &executorWorker
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Indicate if the backend is ready to accept incoming request
|
|
||||||
* @return true if ready, false otherwise
|
|
||||||
*/
|
|
||||||
[[nodiscard]] bool IsReady() const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Query the executor for the number of token available for pulling
|
* Query the executor for the number of token available for pulling
|
||||||
* @return
|
* @return
|
||||||
|
@ -88,32 +120,23 @@ namespace huggingface::tgi::backends {
|
||||||
* @param topK
|
* @param topK
|
||||||
* @param topP
|
* @param topP
|
||||||
* @param temperature
|
* @param temperature
|
||||||
* @param repetition_penalty
|
* @param repetitionPenalty
|
||||||
* @param frequency_penalty
|
* @param frequencyPenalty
|
||||||
* @param seed
|
* @param seed
|
||||||
* @return Request id related to this generation for reference
|
* @return Request id related to this generation for reference
|
||||||
*/
|
*/
|
||||||
[[nodiscard]] RequestId Submit(
|
[[nodiscard]] RequestId Submit(
|
||||||
const std::vector<TokenId> &tokens,
|
const std::vector<TokenId> &tokens,
|
||||||
|
uint32_t maxNewTokens,
|
||||||
int32_t topK,
|
int32_t topK,
|
||||||
float_t topP,
|
float_t topP,
|
||||||
float_t temperature,
|
float_t temperature,
|
||||||
float_t repetition_penalty,
|
float_t repetitionPenalty,
|
||||||
float_t frequency_penalty,
|
float_t frequencyPenalty,
|
||||||
uint64_t seed
|
uint64_t seed
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
[[nodiscard]] std::vector<tle::Response> PullNewTokens();
|
||||||
*
|
|
||||||
* @param requestId The request id to poll the generation results
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
std::vector<tle::Response> Poll(RequestId requestId);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Stop the underlying executor
|
|
||||||
*/
|
|
||||||
void Shutdown();
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,20 +5,31 @@
|
||||||
#ifndef TGI_TRTLLM_BACKEND_FFI_H
|
#ifndef TGI_TRTLLM_BACKEND_FFI_H
|
||||||
#define TGI_TRTLLM_BACKEND_FFI_H
|
#define TGI_TRTLLM_BACKEND_FFI_H
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
#include <memory>
|
||||||
#include "backend.h"
|
#include "backend.h"
|
||||||
|
|
||||||
namespace huggingface::tgi::backends {
|
namespace huggingface::tgi::backends {
|
||||||
class TensorRtLlmBackendImpl;
|
class TensorRtLlmBackendImpl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Template to support returning error from TllmException back to Rust in a Result<>
|
||||||
|
#include <tensorrt_llm/common/tllmException.h>
|
||||||
|
|
||||||
|
namespace rust::behavior {
|
||||||
|
template<typename Try, typename Fail>
|
||||||
|
static void trycatch(Try &&func, Fail &&fail) noexcept try {
|
||||||
|
func();
|
||||||
|
} catch (tensorrt_llm::common::TllmException &e) {
|
||||||
|
fail(e.what());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#include "backends/trtllm/src/lib.rs.h"
|
#include "backends/trtllm/src/lib.rs.h"
|
||||||
|
|
||||||
|
|
||||||
namespace huggingface::tgi::backends {
|
namespace huggingface::tgi::backends {
|
||||||
|
|
||||||
// struct GenerationContext;
|
|
||||||
|
|
||||||
class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
|
class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
|
||||||
public:
|
public:
|
||||||
/***
|
/***
|
||||||
|
@ -28,15 +39,10 @@ namespace huggingface::tgi::backends {
|
||||||
*/
|
*/
|
||||||
TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
|
TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
|
||||||
|
|
||||||
/***
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
bool IsReady() const;
|
|
||||||
|
|
||||||
/***
|
/***
|
||||||
*
|
*
|
||||||
* @param tokens
|
* @param tokens
|
||||||
|
* @param maxNewTokens
|
||||||
* @param topK
|
* @param topK
|
||||||
* @param topP
|
* @param topP
|
||||||
* @param temperature
|
* @param temperature
|
||||||
|
@ -47,21 +53,15 @@ namespace huggingface::tgi::backends {
|
||||||
*/
|
*/
|
||||||
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
|
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
|
||||||
uint64_t
|
uint64_t
|
||||||
Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature,
|
Submit(rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens,
|
||||||
|
int32_t topK, float_t topP, float_t temperature,
|
||||||
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
|
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
|
||||||
|
|
||||||
/***
|
/***
|
||||||
*
|
*
|
||||||
* @param requestId
|
|
||||||
* @param ctx
|
|
||||||
* @param callback
|
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
size_t StreamTokens(
|
std::unique_ptr<std::vector<GenerationStep>> PullTokens();
|
||||||
const RequestId requestId,
|
|
||||||
huggingface::tgi::backends::GenerationContext *ctx,
|
|
||||||
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
|
|
||||||
huggingface::tgi::backends::GenerationStep)> callback);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/***
|
/***
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
namespace huggingface::hardware::cuda {
|
namespace huggingface::hardware::cuda {
|
||||||
|
|
||||||
#define AMPERE_SM_MAJOR 8
|
#define AMPERE_SM_MAJOR 8
|
||||||
#define HOPPER_SM_MAJOR 8
|
#define HOPPER_SM_MAJOR 9
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Store information about the version of the CUDA Compute Capabilities detected on the device
|
* Store information about the version of the CUDA Compute Capabilities detected on the device
|
||||||
|
@ -23,9 +23,9 @@ namespace huggingface::hardware::cuda {
|
||||||
int32_t major;
|
int32_t major;
|
||||||
int32_t minor;
|
int32_t minor;
|
||||||
|
|
||||||
[[nodiscard]] constexpr bool isPostAmpere() const { return major >= AMPERE_SM_MAJOR; }
|
[[nodiscard]] constexpr bool IsPostAmpere() const { return major >= AMPERE_SM_MAJOR; }
|
||||||
|
|
||||||
[[nodiscard]] constexpr bool isPostHopper() const { return major >= HOPPER_SM_MAJOR; }
|
[[nodiscard]] constexpr bool IsPostHopper() const { return major >= HOPPER_SM_MAJOR; }
|
||||||
};
|
};
|
||||||
|
|
||||||
CudaComputeCapabilities GetCudaComputeCapabilities() {
|
CudaComputeCapabilities GetCudaComputeCapabilities() {
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
#include <cstdlib>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
|
||||||
#include <fmt/ranges.h>
|
#include <fmt/ranges.h>
|
||||||
|
@ -7,11 +8,33 @@
|
||||||
#include "backend.h"
|
#include "backend.h"
|
||||||
#include "hardware.h"
|
#include "hardware.h"
|
||||||
|
|
||||||
|
|
||||||
|
void huggingface::tgi::backends::InitializeLogging() {
|
||||||
|
#ifdef NDEBUG
|
||||||
|
if (const auto TRTLLM_LOG_LEVEL_CSTR = std::getenv("TRTLLM_LOG_LEVEL")) {
|
||||||
|
std::string log_level(TRTLLM_LOG_LEVEL_CSTR);
|
||||||
|
std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) {
|
||||||
|
return std::tolower(c);
|
||||||
|
});
|
||||||
|
|
||||||
|
if (log_level == "debug")
|
||||||
|
spdlog::set_level(spdlog::level::debug);
|
||||||
|
else
|
||||||
|
spdlog::set_level(spdlog::level::info);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
spdlog::set_level(spdlog::level::debug);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
void huggingface::tgi::backends::InitializeBackend() {
|
void huggingface::tgi::backends::InitializeBackend() {
|
||||||
SPDLOG_INFO("Initializing Backend...");
|
SPDLOG_INFO("Initializing Backend...");
|
||||||
nvmlInit_v2();
|
nvmlInit_v2();
|
||||||
initTrtLlmPlugins();
|
initTrtLlmPlugins();
|
||||||
|
|
||||||
|
InitializeLogging();
|
||||||
|
|
||||||
|
SPDLOG_INFO("Backend Executor Version: {}", tle::version());
|
||||||
const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
|
const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
|
||||||
if (numGpus.has_value()) {
|
if (numGpus.has_value()) {
|
||||||
SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value());
|
SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value());
|
||||||
|
@ -20,47 +43,49 @@ void huggingface::tgi::backends::InitializeBackend() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[nodiscard]]
|
||||||
|
tle::ParallelConfig
|
||||||
|
huggingface::tgi::backends::GetParallelConfig(const size_t worldSize, const std::string workerPath) noexcept {
|
||||||
|
auto mode = tle::CommunicationMode::kLEADER;
|
||||||
|
std::optional<tle::OrchestratorConfig> orchestratorConfig = std::nullopt;
|
||||||
|
|
||||||
|
if (worldSize > 1) {
|
||||||
|
SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
|
||||||
|
mode = tle::CommunicationMode::kORCHESTRATOR;
|
||||||
|
orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, workerPath, nullptr, true);
|
||||||
|
} else {
|
||||||
|
SPDLOG_INFO("Detected single engine deployment, using leader mode");
|
||||||
|
}
|
||||||
|
|
||||||
|
return tle::ParallelConfig(tle::CommunicationType::kMPI, mode, std::nullopt, std::nullopt, orchestratorConfig);
|
||||||
|
}
|
||||||
|
|
||||||
[[nodiscard]]
|
[[nodiscard]]
|
||||||
tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
|
tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
|
||||||
tle::ExecutorConfig execConfig(1);
|
tle::ExecutorConfig execConfig(/* maxBeamWidth = */ 1);
|
||||||
|
|
||||||
// Retrieve the compute capabilities to enable some options at runtime
|
// Retrieve the compute capabilities to enable some options at runtime
|
||||||
const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
|
const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
|
||||||
|
|
||||||
// Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
|
// Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
|
||||||
if (config["/pretrained_config/mapping/world_size"_json_pointer].get<uint8_t>() == 1) {
|
const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get<size_t>();
|
||||||
SPDLOG_INFO("Detected single engine deployment, using leader mode");
|
execConfig.setParallelConfig(GetParallelConfig(worldSize, workerPath));
|
||||||
execConfig.setParallelConfig(tle::ParallelConfig(
|
|
||||||
tle::CommunicationType::kMPI,
|
|
||||||
tle::CommunicationMode::kLEADER,
|
|
||||||
std::nullopt,
|
|
||||||
std::nullopt,
|
|
||||||
std::nullopt
|
|
||||||
));
|
|
||||||
} else { // Multiple engines -> using orchestrator mode (MPI involved)
|
|
||||||
SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
|
|
||||||
execConfig.setParallelConfig(tle::ParallelConfig(
|
|
||||||
tle::CommunicationType::kMPI,
|
|
||||||
tle::CommunicationMode::kORCHESTRATOR,
|
|
||||||
std::nullopt,
|
|
||||||
std::nullopt,
|
|
||||||
tle::OrchestratorConfig(true, workerPath, nullptr, true)
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Define some configuration variables
|
// Define some configuration variables
|
||||||
execConfig.setKvCacheConfig(tle::KvCacheConfig(true));
|
execConfig.setKvCacheConfig(tle::KvCacheConfig(true));
|
||||||
execConfig.setEnableChunkedContext(computeCapabilities.isPostAmpere());
|
execConfig.setEnableChunkedContext(computeCapabilities.IsPostAmpere());
|
||||||
|
execConfig.setSchedulerConfig(tle::SchedulerConfig(tle::CapacitySchedulerPolicy::kMAX_UTILIZATION));
|
||||||
return execConfig;
|
return execConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
|
tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
|
||||||
uint32_t topK,
|
const uint32_t topK,
|
||||||
float_t topP,
|
const float_t topP,
|
||||||
float_t temperature,
|
const float_t temperature,
|
||||||
float_t repetition_penalty,
|
const float_t repetition_penalty,
|
||||||
float_t frequency_penalty,
|
const float_t frequency_penalty,
|
||||||
uint64_t seed) {
|
const uint64_t seed) noexcept {
|
||||||
|
|
||||||
return tle::SamplingConfig(
|
return tle::SamplingConfig(
|
||||||
1, // TGI only use a single beam
|
1, // TGI only use a single beam
|
||||||
topK,
|
topK,
|
||||||
|
@ -78,69 +103,101 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::optional<std::list<std::vector<huggingface::tgi::backends::TokenId>>>
|
||||||
|
huggingface::tgi::backends::GetStopWordsFromConfig(
|
||||||
|
const std::filesystem::path &generationConfigPath) noexcept {
|
||||||
|
if (exists(generationConfigPath)) {
|
||||||
|
const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
|
||||||
|
if (const auto eosTokenIds = generationConfig["/eos_token_id"_json_pointer]; eosTokenIds.is_array()) {
|
||||||
|
SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
|
||||||
|
std::list<std::vector<huggingface::tgi::backends::TokenId>> stopWords(eosTokenIds.size());
|
||||||
|
|
||||||
|
const auto to_single_token = [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
|
||||||
|
return {tokenIdObj.template get<tle::TokenIdType>()};
|
||||||
|
};
|
||||||
|
|
||||||
|
std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(), to_single_token);
|
||||||
|
return stopWords;
|
||||||
|
} else {
|
||||||
|
SPDLOG_INFO("Invalid EOS tokens entry found (not an array)");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
|
huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
|
||||||
const std::filesystem::path &enginesFolder,
|
const std::filesystem::path &enginesFolder,
|
||||||
const std::filesystem::path &executorWorker
|
const std::filesystem::path &executorWorker
|
||||||
) :
|
) :
|
||||||
config(json::parse(std::ifstream(enginesFolder / "config.json"))),
|
config(json::parse(std::ifstream(enginesFolder / "config.json"))),
|
||||||
executor(
|
executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
|
||||||
enginesFolder,
|
GetExecutorConfig(config, executorWorker.string())) {
|
||||||
tensorrt_llm::executor::ModelType::kDECODER_ONLY,
|
|
||||||
GetExecutorConfig(config, executorWorker.string()
|
|
||||||
)) {
|
|
||||||
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
|
|
||||||
}
|
|
||||||
|
|
||||||
bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const {
|
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get<std::string_view>());
|
||||||
return executor.canEnqueueRequests();
|
|
||||||
|
// Ensure we have enough GPUs on the system
|
||||||
|
const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get<size_t>();
|
||||||
|
const auto numGpus = huggingface::hardware::cuda::GetNumDevices().value_or(0);
|
||||||
|
if (numGpus < worldSize) {
|
||||||
|
SPDLOG_CRITICAL(FMT_NOT_ENOUGH_GPUS, numGpus, worldSize);
|
||||||
|
// todo : raise exception to catch on rust side
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cache variables
|
||||||
|
maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
|
||||||
|
|
||||||
|
// Attempt to discover stopWords from the generation_config.json
|
||||||
|
const auto generationConfigPath = enginesFolder / "generation_config.json";
|
||||||
|
stopWords = GetStopWordsFromConfig(generationConfigPath).value_or(std::list<std::vector<TokenId>>());
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard("Returned number of requests needs to be consumed")]]
|
[[nodiscard("Returned number of requests needs to be consumed")]]
|
||||||
size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
|
size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
|
||||||
|
#ifdef NDEBUG
|
||||||
return executor.getNumResponsesReady();
|
return executor.getNumResponsesReady();
|
||||||
|
#else
|
||||||
|
const auto numResponses = executor.getNumResponsesReady();
|
||||||
|
if (numResponses > 0) SPDLOG_INFO(FMT_STRING("Num responses ready: {:d}"), numResponses);
|
||||||
|
return numResponses;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
|
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
|
||||||
tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
|
tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
|
||||||
const std::vector<tle::TokenIdType> &tokens,
|
const std::vector<tle::TokenIdType> &tokens,
|
||||||
|
const uint32_t maxNewTokens,
|
||||||
const int32_t topK,
|
const int32_t topK,
|
||||||
const float_t topP,
|
const float_t topP,
|
||||||
const float_t temperature,
|
const float_t temperature,
|
||||||
const float_t repetition_penalty,
|
const float_t repetitionPenalty,
|
||||||
const float_t frequency_penalty,
|
const float_t frequencyPenalty,
|
||||||
const uint64_t seed
|
const uint64_t seed
|
||||||
) {
|
) {
|
||||||
#ifdef NDEBUG
|
const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size()));
|
||||||
SPDLOG_DEBUG(
|
#ifndef NDEBUG
|
||||||
FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
|
{
|
||||||
tokens.size(),
|
const auto &iterations = executor.getLatestIterationStats();
|
||||||
executor.getLatestIterationStats().back().numActiveRequests
|
const auto &lastIteration = iterations.front();
|
||||||
);
|
|
||||||
#else
|
SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);
|
||||||
SPDLOG_DEBUG(
|
SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
|
||||||
FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"),
|
SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
|
||||||
fmt::join(tokens, ", "),
|
}
|
||||||
executor.getLatestIterationStats().front().numActiveRequests
|
|
||||||
);
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
|
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
|
||||||
const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
|
|
||||||
|
|
||||||
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
// Build the request
|
||||||
const auto output = tle::OutputConfig(true, false, false, true, false);
|
auto request = tle::Request{tokens, CAST_SIZETYPE(maxNewTokensChecked), true, sampling, OUTPUT_CONFIG};
|
||||||
return executor.enqueueRequest(
|
request.setStopWords(stopWords);
|
||||||
tle::Request{tokens, maxNewTokens, true, sampling, output});
|
|
||||||
|
// Submit to the executor for batching
|
||||||
|
return executor.enqueueRequest(request);
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard("Generated tokens result must be used")]]
|
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
|
||||||
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) {
|
return executor.awaitResponses();
|
||||||
SPDLOG_DEBUG(FMT_STRING("Polling status for request {:d}"), requestId);
|
|
||||||
return executor.awaitResponses(requestId);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() {
|
|
||||||
SPDLOG_INFO("Shutting down executor");
|
|
||||||
executor.shutdown();
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,12 +2,13 @@
|
||||||
|
|
||||||
set -ex
|
set -ex
|
||||||
|
|
||||||
TRT_VER="10.2.0.19"
|
TRT_VER_BASE="10.4.0"
|
||||||
CUDA_VER="12.5"
|
TRT_VER_FULL="${TRT_VER_BASE}.26"
|
||||||
CUDNN_VER="9.2.1.18-1"
|
CUDA_VER="12.6"
|
||||||
NCCL_VER="2.22.3-1+cuda12.5"
|
CUDNN_VER="9.5.0.50-1"
|
||||||
CUBLAS_VER="12.5.3.2-1"
|
NCCL_VER="2.22.3-1+cuda12.6"
|
||||||
NVRTC_VER="12.5.82-1"
|
CUBLAS_VER="12.6.3.3-1"
|
||||||
|
NVRTC_VER="12.6.77-1"
|
||||||
|
|
||||||
for i in "$@"; do
|
for i in "$@"; do
|
||||||
case $i in
|
case $i in
|
||||||
|
@ -32,8 +33,9 @@ install_ubuntu_requirements() {
|
||||||
ARCH=$(uname -m)
|
ARCH=$(uname -m)
|
||||||
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
||||||
if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
|
if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
|
||||||
curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/${ARCH}/cuda-keyring_1.0-1_all.deb
|
curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/${ARCH}/cuda-keyring_1.1-1_all.deb
|
||||||
dpkg -i cuda-keyring_1.0-1_all.deb
|
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
|
rm /etc/apt/sources.list.d/cuda-ubuntu2404-x86_64.list
|
||||||
|
|
||||||
apt-get update
|
apt-get update
|
||||||
if [[ $(apt list --installed | grep libcudnn9) ]]; then
|
if [[ $(apt list --installed | grep libcudnn9) ]]; then
|
||||||
|
@ -71,7 +73,7 @@ install_centos_requirements() {
|
||||||
install_tensorrt() {
|
install_tensorrt() {
|
||||||
#PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
|
#PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
|
||||||
#PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
|
#PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
|
||||||
TRT_CUDA_VERSION="12.5"
|
TRT_CUDA_VERSION="12.6"
|
||||||
|
|
||||||
if [ -z "$RELEASE_URL_TRT" ];then
|
if [ -z "$RELEASE_URL_TRT" ];then
|
||||||
ARCH=${TRT_TARGETARCH}
|
ARCH=${TRT_TARGETARCH}
|
||||||
|
@ -79,12 +81,12 @@ install_tensorrt() {
|
||||||
if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
|
if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
|
||||||
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
||||||
if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi
|
if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi
|
||||||
if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
|
if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-24.04" && OS="ubuntu-24.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
|
||||||
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
|
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/${TRT_VER_BASE}/tars/TensorRT-${TRT_VER_FULL}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
|
||||||
fi
|
fi
|
||||||
wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
|
wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
|
||||||
tar -xf /tmp/TensorRT.tar -C /usr/local/
|
tar -xf /tmp/TensorRT.tar -C /usr/local/
|
||||||
mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt
|
mv /usr/local/TensorRT-${TRT_VER_FULL} /usr/local/tensorrt
|
||||||
# pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
|
# pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
|
||||||
rm -rf /tmp/TensorRT.tar
|
rm -rf /tmp/TensorRT.tar
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,330 +0,0 @@
|
||||||
use std::future::Future;
|
|
||||||
use std::path::Path;
|
|
||||||
use std::pin::{pin, Pin};
|
|
||||||
use std::str::FromStr;
|
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
|
||||||
use std::sync::{Arc, OnceLock};
|
|
||||||
use std::task::{Context, Poll};
|
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use cxx::UniquePtr;
|
|
||||||
use log::{error, warn};
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
|
|
||||||
use tokio::time::{sleep, Instant};
|
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
|
||||||
use tokio_stream::{Stream, StreamExt};
|
|
||||||
use tracing::{instrument, span, Level};
|
|
||||||
|
|
||||||
// use tokio::sync::RwLock;
|
|
||||||
use parking_lot::RwLock;
|
|
||||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
|
||||||
use text_generation_router::validation::ValidationError::UnsupportedModality;
|
|
||||||
use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError};
|
|
||||||
use text_generation_router::{FinishReason, Token};
|
|
||||||
|
|
||||||
use crate::errors::TensorRtLlmBackendError;
|
|
||||||
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
|
|
||||||
|
|
||||||
// Value used to poll the state of the generation stream
|
|
||||||
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
|
|
||||||
|
|
||||||
type InferResult<T> = Result<T, InferError>;
|
|
||||||
|
|
||||||
pub(crate) struct Generation {
|
|
||||||
executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
|
|
||||||
done: Arc<AtomicBool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Holds the user provided input to be executed along with a channel allowing
|
|
||||||
/// to bubble up all the generated tokens for that tokens the to end stream.
|
|
||||||
pub struct GenerationContext {
|
|
||||||
sender: UnboundedSender<InferResult<InferStreamResponse>>,
|
|
||||||
tokenizer: Arc<Tokenizer>,
|
|
||||||
tokens: Vec<u32>,
|
|
||||||
done: Arc<AtomicBool>,
|
|
||||||
queued: Instant,
|
|
||||||
start: Option<Instant>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Stream for Generation {
|
|
||||||
type Item = usize;
|
|
||||||
|
|
||||||
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
||||||
let interval = POLLING_INTERVAL_US.get_or_init(|| {
|
|
||||||
u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100"))
|
|
||||||
.expect("Invalid value provided for envvar POLLING_INTERVAL_US")
|
|
||||||
});
|
|
||||||
|
|
||||||
if !self.done.load(Ordering::Relaxed) {
|
|
||||||
let backend = pin!(self.executor.read());
|
|
||||||
let status = match backend.poll(ctx) {
|
|
||||||
Poll::Ready(executor_r) => {
|
|
||||||
let ready = executor_r.num_responses_ready();
|
|
||||||
if ready == 0 {
|
|
||||||
Poll::Pending
|
|
||||||
} else {
|
|
||||||
Poll::Ready(Some(ready))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Poll::Pending => Poll::Pending,
|
|
||||||
};
|
|
||||||
|
|
||||||
let waker = ctx.waker().clone();
|
|
||||||
tokio::spawn(async {
|
|
||||||
sleep(Duration::from_micros(*interval)).await;
|
|
||||||
waker.wake();
|
|
||||||
});
|
|
||||||
|
|
||||||
status
|
|
||||||
} else {
|
|
||||||
Poll::Ready(None) // end of stream
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
|
||||||
(1, None)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe impl Send for TensorRtLlmBackendImpl {}
|
|
||||||
unsafe impl Sync for TensorRtLlmBackendImpl {}
|
|
||||||
|
|
||||||
/// Implements the logic to execute generation with TensorRT-LLM executor API in background
|
|
||||||
pub struct TensorRtLlmBackend {
|
|
||||||
tokenizer: Arc<Tokenizer>,
|
|
||||||
|
|
||||||
// Backing the backend behind a RwLock to allow concurrent read access to retrieve
|
|
||||||
// the number of available tokens (read only) in the Generation stream
|
|
||||||
backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TensorRtLlmBackend {
|
|
||||||
pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
engine_folder: P,
|
|
||||||
executor_worker_path: PP,
|
|
||||||
) -> Result<Self, TensorRtLlmBackendError> {
|
|
||||||
Ok(TensorRtLlmBackend {
|
|
||||||
tokenizer: Arc::new(tokenizer),
|
|
||||||
backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
|
|
||||||
engine_folder.as_ref().to_str().unwrap(),
|
|
||||||
executor_worker_path.as_ref().to_str().unwrap(),
|
|
||||||
))),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
|
|
||||||
if request.top_n_tokens > 1 {
|
|
||||||
return Err(InferError::ValidationError(
|
|
||||||
ValidationError::TopNTokensDisabled,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Is it really needed? How can it be validated before?
|
|
||||||
if request.parameters.grammar.is_some() {
|
|
||||||
return Err(InferError::ValidationError(ValidationError::Grammar));
|
|
||||||
}
|
|
||||||
|
|
||||||
match request.inputs.len() {
|
|
||||||
0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
|
|
||||||
2.. => Err(InferError::GenerationError(
|
|
||||||
"TensorRT-LLM backend don't support multi-chunk".into(),
|
|
||||||
)),
|
|
||||||
1 => match request.inputs.first().expect("Single item-chunk") {
|
|
||||||
Chunk::Text(text) => Ok(text),
|
|
||||||
Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn generate(
|
|
||||||
&self,
|
|
||||||
sender: UnboundedSender<InferResult<InferStreamResponse>>,
|
|
||||||
tokens: Vec<u32>,
|
|
||||||
top_k: u32,
|
|
||||||
top_p: f32,
|
|
||||||
temperature: f32,
|
|
||||||
repetition_penalty: f32,
|
|
||||||
frequency_penalty: f32,
|
|
||||||
seed: u64,
|
|
||||||
) {
|
|
||||||
let tokenizer = Arc::clone(&self.tokenizer);
|
|
||||||
let executor = Arc::clone(&self.backend);
|
|
||||||
|
|
||||||
// Let's push this in async context
|
|
||||||
tokio::spawn(async move {
|
|
||||||
// Define the generation state
|
|
||||||
let mut generation = Generation {
|
|
||||||
executor: executor.clone(),
|
|
||||||
done: Arc::new(AtomicBool::new(false)),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Define the context over the generation
|
|
||||||
// TODO(asap): Do we really need so many shared-ownership?
|
|
||||||
let ctx = Box::new(GenerationContext {
|
|
||||||
sender: sender.clone(),
|
|
||||||
tokenizer,
|
|
||||||
tokens: vec![],
|
|
||||||
done: Arc::clone(&generation.done),
|
|
||||||
start: None,
|
|
||||||
queued: Instant::now(),
|
|
||||||
});
|
|
||||||
|
|
||||||
// We are leaking the context on-purpose to avoid the box being dropped while there are
|
|
||||||
// still computation ongoing
|
|
||||||
// TODO(asap): Can we achieve the same with an Arc<Box<T>> without the need to go unsafe?
|
|
||||||
let ctx_ = Box::leak(ctx);
|
|
||||||
|
|
||||||
// Submit the request to the batcher
|
|
||||||
let request_id = span!(Level::DEBUG, "submit")
|
|
||||||
.in_scope(|| async {
|
|
||||||
let mut handle = executor.write().await;
|
|
||||||
let request_id = handle.pin_mut().submit(
|
|
||||||
&tokens,
|
|
||||||
top_k as i32,
|
|
||||||
top_p,
|
|
||||||
temperature,
|
|
||||||
repetition_penalty,
|
|
||||||
frequency_penalty,
|
|
||||||
seed,
|
|
||||||
);
|
|
||||||
|
|
||||||
request_id
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
|
|
||||||
while let Some(_) = generation.next().await {
|
|
||||||
let mut executor_w = executor.write().await;
|
|
||||||
let executor = executor_w.pin_mut();
|
|
||||||
|
|
||||||
span!(Level::DEBUG, "decode")
|
|
||||||
.in_scope(|| async {
|
|
||||||
unsafe {
|
|
||||||
executor.stream_tokens(
|
|
||||||
request_id,
|
|
||||||
ctx_,
|
|
||||||
|ctx: *mut GenerationContext, step: GenerationStep| {
|
|
||||||
let inner_ctx = &mut *ctx;
|
|
||||||
|
|
||||||
// Update the timestamp at which the request started effectively
|
|
||||||
// Can be a bit off, would need to be before the callback, let's see
|
|
||||||
inner_ctx.start.get_or_insert(Instant::now());
|
|
||||||
inner_ctx.done.store(step.is_final, Ordering::Relaxed);
|
|
||||||
|
|
||||||
// Ensure we are not running into errors
|
|
||||||
let parcel = if !step.has_error {
|
|
||||||
// Insert the latest generated token to the tracker
|
|
||||||
inner_ctx.tokens.push(step.token_id);
|
|
||||||
|
|
||||||
// Decode the token
|
|
||||||
let text = inner_ctx
|
|
||||||
.tokenizer
|
|
||||||
.decode(&[step.token_id], true)
|
|
||||||
.expect("Failed to decode token");
|
|
||||||
|
|
||||||
let special = inner_ctx
|
|
||||||
.tokenizer
|
|
||||||
.get_added_vocabulary()
|
|
||||||
.is_special_token(&text);
|
|
||||||
|
|
||||||
// Create the structure holding the token
|
|
||||||
let token = Token {
|
|
||||||
id: step.token_id,
|
|
||||||
text,
|
|
||||||
logprob: step.log_prob,
|
|
||||||
special,
|
|
||||||
};
|
|
||||||
|
|
||||||
if step.is_final {
|
|
||||||
let generated_text = inner_ctx
|
|
||||||
.tokenizer
|
|
||||||
.decode(&inner_ctx.tokens, true)
|
|
||||||
.expect("Failed to decode generated_tokens");
|
|
||||||
|
|
||||||
Ok(InferStreamResponse::End {
|
|
||||||
token,
|
|
||||||
top_tokens: vec![],
|
|
||||||
generated_text: GeneratedText {
|
|
||||||
text: generated_text,
|
|
||||||
generated_tokens: inner_ctx.tokens.len() as u32,
|
|
||||||
finish_reason: FinishReason::EndOfSequenceToken,
|
|
||||||
seed: None,
|
|
||||||
},
|
|
||||||
start: inner_ctx.start.unwrap_or(Instant::now()),
|
|
||||||
queued: inner_ctx.queued,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
Ok(InferStreamResponse::Intermediate {
|
|
||||||
token,
|
|
||||||
top_tokens: vec![],
|
|
||||||
})
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
error!("Error caught while decoding: {}", &step.error_msg);
|
|
||||||
Err(InferError::GenerationError(step.error_msg))
|
|
||||||
};
|
|
||||||
|
|
||||||
// Send the parcel to the client
|
|
||||||
inner_ctx
|
|
||||||
.sender
|
|
||||||
.send(parcel)
|
|
||||||
.expect("Failed to sent msg through the channel");
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
// "Properly" free the shared context...
|
|
||||||
// TODO: clean that piece of sh** asap
|
|
||||||
unsafe {
|
|
||||||
let _ = Box::from_raw(ctx_);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Backend for TensorRtLlmBackend {
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
fn schedule(
|
|
||||||
&self,
|
|
||||||
request: ValidGenerateRequest,
|
|
||||||
) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
|
|
||||||
// Let's add a few more validation
|
|
||||||
let input = TensorRtLlmBackend::validate(&request)?;
|
|
||||||
|
|
||||||
// Channel to stream the generated token as they come from the worker thread back to the transport layer
|
|
||||||
let (sender, receiver) = unbounded_channel();
|
|
||||||
|
|
||||||
// Unpack parameters
|
|
||||||
let params = &request.parameters;
|
|
||||||
|
|
||||||
// Preprocess the inputs to send to TRTLLM backend
|
|
||||||
let encoding = self
|
|
||||||
.tokenizer
|
|
||||||
.encode(input.as_str(), true)
|
|
||||||
.map_err(|e| InferError::GenerationError(e.to_string()))?;
|
|
||||||
|
|
||||||
// Generate the response
|
|
||||||
self.generate(
|
|
||||||
sender,
|
|
||||||
Vec::from(encoding.get_ids()),
|
|
||||||
params.top_k,
|
|
||||||
params.top_p,
|
|
||||||
params.temperature,
|
|
||||||
params.repetition_penalty,
|
|
||||||
params.frequency_penalty,
|
|
||||||
params.seed,
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(UnboundedReceiverStream::new(receiver))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn health(&self, _current_health: bool) -> bool {
|
|
||||||
true
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,9 +1,16 @@
|
||||||
|
use std::path::PathBuf;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
use text_generation_router::server;
|
use text_generation_router::server;
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum TensorRtLlmBackendError {
|
pub enum TensorRtLlmBackendError {
|
||||||
|
#[error("Provided engine folder {0} doesn't exist")]
|
||||||
|
EngineFolderDoesntExists(PathBuf),
|
||||||
|
#[error("Provided executorWorker binary path {0} doesn't exist")]
|
||||||
|
ExecutorWorkerNotFound(PathBuf),
|
||||||
|
#[error("TensorRT-LLM Runtime error: {0}")]
|
||||||
|
Runtime(String),
|
||||||
#[error("Tokenizer error: {0}")]
|
#[error("Tokenizer error: {0}")]
|
||||||
Tokenizer(String),
|
Tokenizer(String),
|
||||||
#[error("Argument validation error: {0}")]
|
#[error("Argument validation error: {0}")]
|
||||||
|
|
|
@ -3,11 +3,13 @@
|
||||||
//
|
//
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <cmath>
|
#include <algorithm>
|
||||||
#include <exception>
|
#include <exception>
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
|
#include <functional>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
|
#include <ranges>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include <spdlog/spdlog.h>
|
#include <spdlog/spdlog.h>
|
||||||
|
@ -20,61 +22,64 @@ huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
|
||||||
) : TensorRtLlmBackend(engineFolder, executorWorker) {}
|
) : TensorRtLlmBackend(engineFolder, executorWorker) {}
|
||||||
|
|
||||||
|
|
||||||
bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
|
|
||||||
return TensorRtLlmBackend::IsReady();
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
|
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
|
||||||
rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty,
|
rust::Slice<const uint32_t> tokens,
|
||||||
float_t frequency_penalty, uint64_t seed) {
|
uint32_t maxNewTokens,
|
||||||
|
int32_t topK,
|
||||||
|
float_t topP,
|
||||||
|
float_t temperature,
|
||||||
|
float_t repetition_penalty,
|
||||||
|
float_t frequency_penalty,
|
||||||
|
uint64_t seed) {
|
||||||
|
|
||||||
// This will copy all the items from the initial slice
|
// This will copy all the items from the initial slice
|
||||||
std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end()));
|
std::vector<int32_t> tokens_(tokens.begin(), tokens.end());
|
||||||
return TensorRtLlmBackend::Submit(
|
return TensorRtLlmBackend::Submit(
|
||||||
std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
std::move(tokens_), maxNewTokens, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
|
std::unique_ptr<std::vector<huggingface::tgi::backends::GenerationStep>>
|
||||||
const uint64_t requestId,
|
huggingface::tgi::backends::TensorRtLlmBackendImpl::PullTokens() {
|
||||||
huggingface::tgi::backends::GenerationContext *ctx,
|
const auto responses = TensorRtLlmBackend::PullNewTokens();
|
||||||
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
|
|
||||||
huggingface::tgi::backends::GenerationStep)> callback) {
|
|
||||||
|
|
||||||
size_t numTokens = 0;
|
auto steps = std::make_unique<std::vector<GenerationStep>>();
|
||||||
for (const auto &item: Poll(requestId)) {
|
steps->reserve(responses.size());
|
||||||
GenerationStep step;
|
|
||||||
if (!item.hasError()) {
|
|
||||||
SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
|
|
||||||
const auto decoded = item.getResult();
|
|
||||||
|
|
||||||
const auto token = decoded.outputTokenIds[0][0];
|
#ifndef NDEBUG
|
||||||
const auto isFinal = decoded.isFinal;
|
SPDLOG_DEBUG(FMT_STRING("Pulled out {:d} new tokens"), responses->size());
|
||||||
const auto logProb = decoded.logProbs.value()[0][0];
|
#endif
|
||||||
|
|
||||||
++numTokens;
|
// Transform tle::Response to GenerationStep
|
||||||
|
std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [](const tle::Response &r) {
|
||||||
SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
|
const auto reqId = r.getRequestId();
|
||||||
step = huggingface::tgi::backends::GenerationStep{
|
if (!r.hasError()) {
|
||||||
static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
|
const auto result = r.getResult();
|
||||||
|
return GenerationStep{
|
||||||
|
reqId,
|
||||||
|
static_cast<uint32_t>(result.outputTokenIds[0][0]),
|
||||||
|
result.logProbs.value()[0][0],
|
||||||
|
result.isFinal,
|
||||||
|
false,
|
||||||
|
std::string()
|
||||||
};
|
};
|
||||||
SPDLOG_DEBUG("\tStreamTokens -> Post callback");
|
|
||||||
} else {
|
} else {
|
||||||
// TODO : Return rest::Result with error
|
return GenerationStep{
|
||||||
const auto what = item.getErrorMsg();
|
reqId,
|
||||||
SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
|
0,
|
||||||
step = huggingface::tgi::backends::GenerationStep{
|
0.0,
|
||||||
std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
|
true,
|
||||||
|
true,
|
||||||
|
std::move(r.getErrorMsg())
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
});
|
||||||
|
|
||||||
callback(std::move(ctx), std::move(step));
|
return steps;
|
||||||
}
|
|
||||||
|
|
||||||
return numTokens;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
|
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
|
||||||
huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
|
huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
|
||||||
|
SPDLOG_INFO("Creating TensorRT-LLM Backend");
|
||||||
// Unconditionally call this to initialize and discover TRTLLM plugins
|
// Unconditionally call this to initialize and discover TRTLLM plugins
|
||||||
InitializeBackend();
|
InitializeBackend();
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,16 @@
|
||||||
pub use backend::{GenerationContext, TensorRtLlmBackend};
|
pub use looper::TensorRtLlmBackendV2;
|
||||||
|
|
||||||
mod backend;
|
|
||||||
pub mod errors;
|
pub mod errors;
|
||||||
|
mod looper;
|
||||||
|
mod utils;
|
||||||
|
|
||||||
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
|
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
|
||||||
mod ffi {
|
mod ffi {
|
||||||
|
|
||||||
/// Struct used as shared type between rust and C++ to represent the result
|
/// Struct used as shared type between rust and C++ to represent the result
|
||||||
/// of a single decoding iteration
|
/// of a single decoding iteration
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct GenerationStep {
|
pub struct GenerationStep {
|
||||||
|
request_id: u64,
|
||||||
token_id: u32,
|
token_id: u32,
|
||||||
log_prob: f32,
|
log_prob: f32,
|
||||||
is_final: bool,
|
is_final: bool,
|
||||||
|
@ -16,10 +18,6 @@ mod ffi {
|
||||||
error_msg: String,
|
error_msg: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "Rust" {
|
|
||||||
type GenerationContext;
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe extern "C++" {
|
unsafe extern "C++" {
|
||||||
include!("backends/trtllm/src/ffi.cpp");
|
include!("backends/trtllm/src/ffi.cpp");
|
||||||
|
|
||||||
|
@ -44,10 +42,7 @@ mod ffi {
|
||||||
fn CreateTensorRtLlmBackend(
|
fn CreateTensorRtLlmBackend(
|
||||||
engine_folder: &str,
|
engine_folder: &str,
|
||||||
executor_worker: &str,
|
executor_worker: &str,
|
||||||
) -> UniquePtr<TensorRtLlmBackendImpl>;
|
) -> Result<UniquePtr<TensorRtLlmBackendImpl>>;
|
||||||
|
|
||||||
// #[rust_name = "is_ready"]
|
|
||||||
// fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
|
|
||||||
|
|
||||||
#[rust_name = "num_responses_ready"]
|
#[rust_name = "num_responses_ready"]
|
||||||
fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
|
fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
|
||||||
|
@ -56,23 +51,18 @@ mod ffi {
|
||||||
fn Submit(
|
fn Submit(
|
||||||
self: Pin<&mut TensorRtLlmBackendImpl>,
|
self: Pin<&mut TensorRtLlmBackendImpl>,
|
||||||
tokens: &[u32],
|
tokens: &[u32],
|
||||||
|
max_new_tokens: u32,
|
||||||
top_k: i32,
|
top_k: i32,
|
||||||
top_p: f32,
|
top_p: f32,
|
||||||
temperature: f32,
|
temperature: f32,
|
||||||
repetition_penalty: f32,
|
repetition_penalty: f32,
|
||||||
frequency_penalty: f32,
|
frequency_penalty: f32,
|
||||||
seed: u64,
|
seed: u64,
|
||||||
) -> u64;
|
) -> Result<u64>;
|
||||||
|
|
||||||
#[rust_name = "stream_tokens"]
|
#[rust_name = "pull_tokens"]
|
||||||
unsafe fn StreamTokens(
|
fn PullTokens(
|
||||||
self: Pin<&mut TensorRtLlmBackendImpl>,
|
self: Pin<&mut TensorRtLlmBackendImpl>,
|
||||||
request_id: u64,
|
) -> Result<UniquePtr<CxxVector<GenerationStep>>>;
|
||||||
ctx: *mut GenerationContext,
|
|
||||||
cb: unsafe fn(*mut GenerationContext, GenerationStep),
|
|
||||||
) -> usize;
|
|
||||||
|
|
||||||
// #[rust_name = "shutdown"]
|
|
||||||
// fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,382 @@
|
||||||
|
use std::hint;
|
||||||
|
use std::ops::Deref;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use cxx::UniquePtr;
|
||||||
|
use hashbrown::HashMap;
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
|
||||||
|
use tokio::sync::TryAcquireError;
|
||||||
|
use tokio::task::{spawn_blocking, JoinHandle};
|
||||||
|
use tokio::time::Instant;
|
||||||
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
|
use tracing::{debug, error, warn};
|
||||||
|
|
||||||
|
use text_generation_router::infer::InferError::{GenerationError, ValidationError};
|
||||||
|
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||||
|
use text_generation_router::validation::ValidationError::{
|
||||||
|
EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
|
||||||
|
};
|
||||||
|
use text_generation_router::validation::{Chunk, ValidGenerateRequest};
|
||||||
|
use text_generation_router::{FinishReason, Token};
|
||||||
|
|
||||||
|
use crate::errors::TensorRtLlmBackendError;
|
||||||
|
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
|
||||||
|
use crate::utils::first_line;
|
||||||
|
|
||||||
|
type InferResult<T> = Result<T, InferError>;
|
||||||
|
|
||||||
|
/// Wrap the requests along with the channel used to stream back to the client the decoded tokens
|
||||||
|
struct GenerationContext {
|
||||||
|
request: ValidGenerateRequest,
|
||||||
|
start: Option<Instant>,
|
||||||
|
queued: Instant,
|
||||||
|
streamer: UnboundedSender<InferResult<InferStreamResponse>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone)]
|
||||||
|
struct DecodedToken {
|
||||||
|
id: u32,
|
||||||
|
log_prob: f32,
|
||||||
|
is_final: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
|
||||||
|
type Error = InferError;
|
||||||
|
|
||||||
|
fn try_from(step: &'step GenerationStep) -> Result<Self, Self::Error> {
|
||||||
|
if !step.has_error {
|
||||||
|
Ok(Self {
|
||||||
|
id: step.token_id,
|
||||||
|
log_prob: step.log_prob,
|
||||||
|
is_final: step.is_final,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Err(GenerationError(step.error_msg.clone()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wraps the decoded token with the channel used to stream back to the client the decoded tokens
|
||||||
|
struct DecodedTokenContext {
|
||||||
|
token: DecodedToken,
|
||||||
|
start: Option<Instant>,
|
||||||
|
queued: Instant,
|
||||||
|
channel: UnboundedSender<InferResult<InferStreamResponse>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn executor_status_looper(
|
||||||
|
mut backend: UniquePtr<TensorRtLlmBackendImpl>,
|
||||||
|
max_inflight_requests: usize,
|
||||||
|
mut waiting_requests: UnboundedReceiver<GenerationContext>,
|
||||||
|
post_processor_sender: UnboundedSender<(u64, InferResult<DecodedTokenContext>)>,
|
||||||
|
) {
|
||||||
|
// Track the tuple (request_id, stream) for each request
|
||||||
|
let mut in_flights =
|
||||||
|
HashMap::<u64, GenerationContext>::with_capacity(max_inflight_requests * 2);
|
||||||
|
|
||||||
|
// TODO: Does it need a spin-loop?
|
||||||
|
'scheduler: loop {
|
||||||
|
// Is there any request pending to be scheduled?
|
||||||
|
let awaiting_requests = waiting_requests.len();
|
||||||
|
for _ in 0..awaiting_requests {
|
||||||
|
// Retrieve all the requests
|
||||||
|
if let Some(mut ctx) = waiting_requests.blocking_recv() {
|
||||||
|
// Submit all the request to the executor and move the context to the in-flight tracker
|
||||||
|
let request = &ctx.request;
|
||||||
|
let generation_params = &request.parameters;
|
||||||
|
let stopping_params = &request.stopping_parameters;
|
||||||
|
let input_ids = request.input_ids.as_deref();
|
||||||
|
|
||||||
|
// Submit to the TensorRT-LLM executor for scheduling
|
||||||
|
match backend.pin_mut().submit(
|
||||||
|
&input_ids.unwrap(), // This is checked beforehand in validate()
|
||||||
|
stopping_params.max_new_tokens,
|
||||||
|
generation_params.top_k as i32,
|
||||||
|
generation_params.top_p,
|
||||||
|
generation_params.temperature,
|
||||||
|
generation_params.repetition_penalty,
|
||||||
|
generation_params.frequency_penalty,
|
||||||
|
generation_params.seed,
|
||||||
|
) {
|
||||||
|
Ok(request_id) => {
|
||||||
|
// Insert the context linked to the generated request id in the tracker
|
||||||
|
debug!("[in-flight] Added {}", request_id);
|
||||||
|
ctx.start = Some(Instant::now());
|
||||||
|
in_flights.insert(request_id, ctx);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// Return to the caller
|
||||||
|
let what = e.to_string();
|
||||||
|
error!(error = what.as_str(), "Failed to schedule request");
|
||||||
|
|
||||||
|
let err = Err(InferError::Overloaded(TryAcquireError::NoPermits));
|
||||||
|
if let Err(_) = ctx.streamer.send(err) {
|
||||||
|
error!("Failed to send back error to the client");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if backend.num_responses_ready() > 0 {
|
||||||
|
match backend.pin_mut().pull_tokens() {
|
||||||
|
Ok(responses) => {
|
||||||
|
// Iterate through all the decoded token
|
||||||
|
for step in responses.deref() {
|
||||||
|
if let Some(ctx) = in_flights.get(&step.request_id) {
|
||||||
|
// Remove from tracked requests
|
||||||
|
let parcel =
|
||||||
|
DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
|
||||||
|
token: dt,
|
||||||
|
start: ctx.start,
|
||||||
|
queued: ctx.queued,
|
||||||
|
channel: ctx.streamer.clone(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Submit the work to p:the post_processor
|
||||||
|
let posted = post_processor_sender.send((step.request_id, parcel));
|
||||||
|
|
||||||
|
if posted.is_err() || step.is_final {
|
||||||
|
debug!("Removing {}", step.request_id);
|
||||||
|
let _ = in_flights.remove(&step.request_id);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
warn!("Untracked request {}", step.request_id,);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(ref err) => {
|
||||||
|
error!("Failed to get responses from the executor: {}.", err.what());
|
||||||
|
break 'scheduler;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hint the CPU we are spin-locking
|
||||||
|
hint::spin_loop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn post_processor_looper<const MAX_NUM_TOKENS: usize>(
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
max_inflight_requests: usize,
|
||||||
|
mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
|
||||||
|
) {
|
||||||
|
let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(max_inflight_requests * 2);
|
||||||
|
|
||||||
|
'post_processor: loop {
|
||||||
|
if decoded_tokens.is_closed() {
|
||||||
|
warn!("Post processor IPC is closed, loop will exit now.");
|
||||||
|
break 'post_processor;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
|
||||||
|
match decoded {
|
||||||
|
Ok(ctx) => {
|
||||||
|
states
|
||||||
|
.entry(request_id)
|
||||||
|
.and_modify(|s| s.push(*&ctx.token.id))
|
||||||
|
.or_insert_with(|| {
|
||||||
|
let mut state = Vec::with_capacity(MAX_NUM_TOKENS);
|
||||||
|
state.push(*&ctx.token.id);
|
||||||
|
state
|
||||||
|
});
|
||||||
|
|
||||||
|
let out = match tokenizer.decode(&[ctx.token.id], false) {
|
||||||
|
Ok(text) => {
|
||||||
|
let is_special =
|
||||||
|
tokenizer.get_added_vocabulary().is_special_token(&text);
|
||||||
|
let token = Token {
|
||||||
|
id: ctx.token.id,
|
||||||
|
text,
|
||||||
|
logprob: ctx.token.log_prob,
|
||||||
|
special: is_special,
|
||||||
|
};
|
||||||
|
|
||||||
|
let out = if !ctx.token.is_final {
|
||||||
|
InferStreamResponse::Intermediate {
|
||||||
|
token,
|
||||||
|
top_tokens: vec![],
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let tokens = states.remove(&request_id).unwrap();
|
||||||
|
let text = tokenizer.decode(&tokens, true);
|
||||||
|
let generated_text = GeneratedText {
|
||||||
|
text: text.unwrap(),
|
||||||
|
generated_tokens: tokens.len() as u32,
|
||||||
|
finish_reason: FinishReason::EndOfSequenceToken,
|
||||||
|
seed: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
InferStreamResponse::End {
|
||||||
|
token,
|
||||||
|
top_tokens: vec![],
|
||||||
|
generated_text,
|
||||||
|
start: ctx.start.unwrap(),
|
||||||
|
queued: ctx.queued,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
Err(err) => Err(GenerationError(err.to_string())),
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Err(_) = ctx.channel.send(out) {
|
||||||
|
warn!("Failed to send decoded token back to the user")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_err) => {
|
||||||
|
todo!("what do we do?")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
|
||||||
|
engine_folder: P,
|
||||||
|
executor_worker_path: PP,
|
||||||
|
) -> Result<(String, String), TensorRtLlmBackendError> {
|
||||||
|
// Retrieve paths as &str for the backend creation
|
||||||
|
let engine_folder = engine_folder.as_ref();
|
||||||
|
let executor_worker_path = executor_worker_path.as_ref();
|
||||||
|
|
||||||
|
// Ensure the engine folder exists
|
||||||
|
if !engine_folder.exists() {
|
||||||
|
let err = TensorRtLlmBackendError::EngineFolderDoesntExists(engine_folder.to_path_buf());
|
||||||
|
|
||||||
|
error!("Path validation failed: {}", err,);
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure executor worker binary exists
|
||||||
|
if !executor_worker_path.exists() {
|
||||||
|
let err = TensorRtLlmBackendError::ExecutorWorkerNotFound(engine_folder.to_path_buf());
|
||||||
|
|
||||||
|
error!("Path validation failed: {}", err,);
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
let engine_folder = String::from(
|
||||||
|
engine_folder
|
||||||
|
.to_str()
|
||||||
|
.expect("Failed to convert engine_folder to valid UTF-8"),
|
||||||
|
);
|
||||||
|
|
||||||
|
let executor_worker_path = String::from(
|
||||||
|
executor_worker_path
|
||||||
|
.to_str()
|
||||||
|
.expect("Failed to convert executor_worker_path to valid UTF-8"),
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok((engine_folder, executor_worker_path))
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe impl Send for TensorRtLlmBackendImpl {}
|
||||||
|
|
||||||
|
pub struct TensorRtLlmBackendV2 {
|
||||||
|
executor_looper: JoinHandle<()>,
|
||||||
|
post_processor_looper: JoinHandle<()>,
|
||||||
|
executor: UnboundedSender<GenerationContext>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TensorRtLlmBackendV2 {
|
||||||
|
pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
engine_folder: P,
|
||||||
|
executor_worker_path: PP,
|
||||||
|
max_inflight_requests: usize,
|
||||||
|
) -> Result<Self, TensorRtLlmBackendError> {
|
||||||
|
let (engine_folder, executor_worker_path) =
|
||||||
|
ensure_paths_exist(engine_folder, executor_worker_path)?;
|
||||||
|
|
||||||
|
// Allocate the IPC layer to communicate with the backend
|
||||||
|
let (executor_sender, executor_receiver) = unbounded_channel();
|
||||||
|
let (post_processor_sender, post_processor_receiver) = unbounded_channel();
|
||||||
|
|
||||||
|
// Create the FFI backend
|
||||||
|
let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path)
|
||||||
|
.map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
|
||||||
|
|
||||||
|
// Executor looper is responsible for scheduling and pulling requests state at regular interval
|
||||||
|
let executor_looper = spawn_blocking(move || {
|
||||||
|
executor_status_looper(
|
||||||
|
backend,
|
||||||
|
max_inflight_requests,
|
||||||
|
executor_receiver,
|
||||||
|
post_processor_sender,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
// Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
|
||||||
|
let post_processor_looper = spawn_blocking(move || {
|
||||||
|
post_processor_looper::<256>(tokenizer, max_inflight_requests, post_processor_receiver)
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(TensorRtLlmBackendV2 {
|
||||||
|
executor_looper,
|
||||||
|
post_processor_looper,
|
||||||
|
executor: executor_sender,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
|
||||||
|
if request.input_ids.is_none() {
|
||||||
|
return Err(ValidationError(UnsupportedModality("No token provided")));
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.top_n_tokens > 1 {
|
||||||
|
return Err(ValidationError(TopNTokensDisabled));
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Is it really needed? How can it be validated before?
|
||||||
|
if request.parameters.grammar.is_some() {
|
||||||
|
return Err(ValidationError(Grammar));
|
||||||
|
}
|
||||||
|
|
||||||
|
match request.inputs.len() {
|
||||||
|
0 => Err(ValidationError(EmptyInput)),
|
||||||
|
2.. => Err(GenerationError(
|
||||||
|
"TensorRT-LLM backend don't support multi-chunk".into(),
|
||||||
|
)),
|
||||||
|
1 => match request.inputs.first().expect("Single item-chunk") {
|
||||||
|
Chunk::Text(_) => Ok(()),
|
||||||
|
Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Backend for TensorRtLlmBackendV2 {
|
||||||
|
fn schedule(
|
||||||
|
&self,
|
||||||
|
inner: ValidGenerateRequest,
|
||||||
|
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||||
|
Self::validate(&inner)?;
|
||||||
|
|
||||||
|
// Open-up the stream to send tokens
|
||||||
|
let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
|
||||||
|
|
||||||
|
// Send the context to the executor for scheduling
|
||||||
|
let queued = Instant::now();
|
||||||
|
match self.executor.send(GenerationContext {
|
||||||
|
request: inner,
|
||||||
|
start: None,
|
||||||
|
queued,
|
||||||
|
streamer,
|
||||||
|
}) {
|
||||||
|
Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
|
||||||
|
Err(_) => Err(GenerationError(
|
||||||
|
"Failed to submit request to the backend".into(),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health(&self, _: bool) -> bool {
|
||||||
|
!self.executor_looper.is_finished() & !self.post_processor_looper.is_finished()
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,10 +1,16 @@
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use std::collections::HashMap;
|
use hf_hub::api::tokio::{Api, ApiBuilder};
|
||||||
use std::path::PathBuf;
|
use hf_hub::{Cache, Repo, RepoType};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
|
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
|
||||||
use text_generation_backends_trtllm::TensorRtLlmBackend;
|
use text_generation_backends_trtllm::TensorRtLlmBackendV2;
|
||||||
use text_generation_router::server;
|
use text_generation_router::server::get_base_tokenizer;
|
||||||
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
use text_generation_router::usage_stats::UsageStatsLevel;
|
||||||
|
use text_generation_router::{server, HubTokenizerConfig};
|
||||||
|
|
||||||
/// App Configuration
|
/// App Configuration
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
|
@ -48,14 +54,138 @@ struct Args {
|
||||||
otlp_service_name: String,
|
otlp_service_name: String,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
cors_allow_origin: Option<Vec<String>>,
|
cors_allow_origin: Option<Vec<String>>,
|
||||||
#[clap(long, env, default_value_t = false)]
|
|
||||||
messages_api_enabled: bool,
|
|
||||||
#[clap(default_value = "4", long, env)]
|
#[clap(default_value = "4", long, env)]
|
||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
auth_token: Option<String>,
|
auth_token: Option<String>,
|
||||||
#[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
|
#[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
|
||||||
executor_worker: PathBuf,
|
executor_worker: PathBuf,
|
||||||
|
#[clap(default_value = "on", long, env)]
|
||||||
|
usage_stats: usage_stats::UsageStatsLevel,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_tokenizer(
|
||||||
|
tokenizer_name: &str,
|
||||||
|
tokenizer_config_path: Option<&str>,
|
||||||
|
revision: Option<&str>,
|
||||||
|
) -> Option<Tokenizer> {
|
||||||
|
// Parse Huggingface hub token
|
||||||
|
let authorization_token = std::env::var("HF_TOKEN")
|
||||||
|
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
|
||||||
|
.ok();
|
||||||
|
|
||||||
|
// Tokenizer instance
|
||||||
|
let local_path = Path::new(tokenizer_name);
|
||||||
|
|
||||||
|
// Shared API builder initialization
|
||||||
|
let api_builder = || {
|
||||||
|
let mut builder = ApiBuilder::new()
|
||||||
|
.with_progress(false)
|
||||||
|
.with_token(authorization_token);
|
||||||
|
|
||||||
|
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
|
||||||
|
builder = builder.with_cache_dir(cache_dir.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
builder
|
||||||
|
};
|
||||||
|
|
||||||
|
// Decide if we need to use the API based on the revision and local path
|
||||||
|
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
|
||||||
|
|
||||||
|
// Initialize API if needed
|
||||||
|
#[derive(Clone)]
|
||||||
|
enum Type {
|
||||||
|
Api(Api),
|
||||||
|
Cache(Cache),
|
||||||
|
None,
|
||||||
|
}
|
||||||
|
let api = if use_api {
|
||||||
|
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
|
||||||
|
let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
|
||||||
|
.map_err(|_| ())
|
||||||
|
.map(|cache_dir| Cache::new(cache_dir.into()))
|
||||||
|
.unwrap_or_else(|_| Cache::default());
|
||||||
|
tracing::warn!("Offline mode active using cache defaults");
|
||||||
|
Type::Cache(cache)
|
||||||
|
} else {
|
||||||
|
tracing::info!("Using the Hugging Face API");
|
||||||
|
match api_builder().build() {
|
||||||
|
Ok(api) => Type::Api(api),
|
||||||
|
Err(_) => {
|
||||||
|
tracing::warn!("Unable to build the Hugging Face API");
|
||||||
|
Type::None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Type::None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Load tokenizer and model info
|
||||||
|
let (
|
||||||
|
tokenizer_filename,
|
||||||
|
_config_filename,
|
||||||
|
tokenizer_config_filename,
|
||||||
|
_preprocessor_config_filename,
|
||||||
|
_processor_config_filename,
|
||||||
|
) = match api {
|
||||||
|
Type::None => (
|
||||||
|
Some(local_path.join("tokenizer.json")),
|
||||||
|
Some(local_path.join("config.json")),
|
||||||
|
Some(local_path.join("tokenizer_config.json")),
|
||||||
|
Some(local_path.join("preprocessor_config.json")),
|
||||||
|
Some(local_path.join("processor_config.json")),
|
||||||
|
),
|
||||||
|
Type::Api(api) => {
|
||||||
|
let api_repo = api.repo(Repo::with_revision(
|
||||||
|
tokenizer_name.to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
revision.unwrap_or_else(|| "main").to_string(),
|
||||||
|
));
|
||||||
|
|
||||||
|
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
|
||||||
|
Ok(tokenizer_filename) => Some(tokenizer_filename),
|
||||||
|
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
||||||
|
};
|
||||||
|
let config_filename = api_repo.get("config.json").await.ok();
|
||||||
|
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
||||||
|
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
|
||||||
|
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
|
||||||
|
|
||||||
|
(
|
||||||
|
tokenizer_filename,
|
||||||
|
config_filename,
|
||||||
|
tokenizer_config_filename,
|
||||||
|
preprocessor_config_filename,
|
||||||
|
processor_config_filename,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Type::Cache(cache) => {
|
||||||
|
let repo = cache.repo(Repo::with_revision(
|
||||||
|
tokenizer_name.to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
revision.clone().unwrap_or_else(|| "main").to_string(),
|
||||||
|
));
|
||||||
|
(
|
||||||
|
repo.get("tokenizer.json"),
|
||||||
|
repo.get("config.json"),
|
||||||
|
repo.get("tokenizer_config.json"),
|
||||||
|
repo.get("preprocessor_config.json"),
|
||||||
|
repo.get("processor_config.json"),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||||
|
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
|
||||||
|
{
|
||||||
|
HubTokenizerConfig::from_file(filename)
|
||||||
|
} else {
|
||||||
|
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
|
||||||
|
};
|
||||||
|
|
||||||
|
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
|
@ -83,10 +213,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
|
||||||
otlp_endpoint,
|
otlp_endpoint,
|
||||||
otlp_service_name,
|
otlp_service_name,
|
||||||
cors_allow_origin,
|
cors_allow_origin,
|
||||||
messages_api_enabled,
|
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
auth_token,
|
auth_token,
|
||||||
executor_worker,
|
executor_worker,
|
||||||
|
usage_stats,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
// Launch Tokio runtime
|
// Launch Tokio runtime
|
||||||
|
@ -124,18 +254,26 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run server
|
// Create the backend
|
||||||
let tokenizer = Tokenizer::from_pretrained(
|
let tokenizer = get_tokenizer(
|
||||||
tokenizer_name.clone(),
|
&tokenizer_name,
|
||||||
Some(FromPretrainedParameters {
|
tokenizer_config_path.as_deref(),
|
||||||
revision: revision.clone().unwrap_or(String::from("main")),
|
revision.as_deref(),
|
||||||
user_agent: HashMap::new(),
|
|
||||||
auth_token,
|
|
||||||
}),
|
|
||||||
)
|
)
|
||||||
.map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
|
.await
|
||||||
|
.expect("Failed to retrieve tokenizer implementation");
|
||||||
|
|
||||||
let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?;
|
info!("Successfully retrieved tokenizer {}", &tokenizer_name);
|
||||||
|
let backend = TensorRtLlmBackendV2::new(
|
||||||
|
tokenizer,
|
||||||
|
model_id,
|
||||||
|
executor_worker,
|
||||||
|
max_concurrent_requests,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
info!("Successfully created backend");
|
||||||
|
|
||||||
|
// Run server
|
||||||
server::run(
|
server::run(
|
||||||
backend,
|
backend,
|
||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
|
@ -145,7 +283,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
|
||||||
max_input_tokens,
|
max_input_tokens,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
validation_workers,
|
validation_workers,
|
||||||
None,
|
auth_token,
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
tokenizer_config_path,
|
tokenizer_config_path,
|
||||||
revision,
|
revision,
|
||||||
|
@ -155,11 +293,9 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
|
||||||
false,
|
false,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
messages_api_enabled,
|
|
||||||
true,
|
true,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
false,
|
usage_stats,
|
||||||
false,
|
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
///
|
||||||
|
/// Extract the first line of the provided string reference.
|
||||||
|
/// If there is no lines in the buffer, it returns a string
|
||||||
|
/// which content is defined by the content of `fail`
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `s`: The string buffer to extract the first-line from
|
||||||
|
/// * `fail`: A string content which is returned if no lines are
|
||||||
|
/// present in `s`
|
||||||
|
///
|
||||||
|
/// returns: String
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// let s = "My name is Morgan.\n I'm working at Hugging Face.";
|
||||||
|
/// first_line(s, "No line in string");
|
||||||
|
/// ```
|
||||||
|
#[inline]
|
||||||
|
pub(crate) fn first_line(s: &str, fail: &str) -> String {
|
||||||
|
s.lines().next().unwrap_or(fail).to_string()
|
||||||
|
}
|
|
@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||||
use text_generation_router::validation::ValidGenerateRequest;
|
use text_generation_router::validation::ValidGenerateRequest;
|
||||||
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
|
use text_generation_router::{FinishReason, PrefillToken, Token};
|
||||||
use tokio::sync::mpsc::error::SendError;
|
use tokio::sync::mpsc::error::SendError;
|
||||||
use tokio::sync::{mpsc, Notify};
|
use tokio::sync::{mpsc, Notify};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
|
@ -36,18 +36,14 @@ impl BackendV2 {
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Infer shared state
|
// Infer shared state
|
||||||
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
|
let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string());
|
||||||
attention
|
let block_size = match attention.as_str() {
|
||||||
.parse()
|
"flashinfer" => 1,
|
||||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
|
"flashdecoding" => 256,
|
||||||
} else {
|
"paged" => 16,
|
||||||
Attention::Paged
|
_ => unreachable!(),
|
||||||
};
|
|
||||||
let block_size = if attention == Attention::FlashDecoding {
|
|
||||||
256
|
|
||||||
} else {
|
|
||||||
16
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
|
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
|
||||||
let batching_task_notifier = Arc::new(Notify::new());
|
let batching_task_notifier = Arc::new(Notify::new());
|
||||||
|
|
||||||
|
|
|
@ -44,6 +44,8 @@ struct Args {
|
||||||
tokenizer_config_path: Option<String>,
|
tokenizer_config_path: Option<String>,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
|
#[clap(long, env, value_enum)]
|
||||||
|
trust_remote_code: bool,
|
||||||
#[clap(default_value = "2", long, env)]
|
#[clap(default_value = "2", long, env)]
|
||||||
validation_workers: usize,
|
validation_workers: usize,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
|
@ -63,8 +65,6 @@ struct Args {
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
ngrok_edge: Option<String>,
|
ngrok_edge: Option<String>,
|
||||||
#[clap(long, env, default_value_t = false)]
|
#[clap(long, env, default_value_t = false)]
|
||||||
messages_api_enabled: bool,
|
|
||||||
#[clap(long, env, default_value_t = false)]
|
|
||||||
disable_grammar_support: bool,
|
disable_grammar_support: bool,
|
||||||
#[clap(default_value = "4", long, env)]
|
#[clap(default_value = "4", long, env)]
|
||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
|
@ -101,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
tokenizer_config_path,
|
tokenizer_config_path,
|
||||||
revision,
|
revision,
|
||||||
|
trust_remote_code,
|
||||||
validation_workers,
|
validation_workers,
|
||||||
api_key,
|
api_key,
|
||||||
json_output,
|
json_output,
|
||||||
|
@ -110,7 +111,6 @@ async fn main() -> Result<(), RouterError> {
|
||||||
ngrok,
|
ngrok,
|
||||||
ngrok_authtoken,
|
ngrok_authtoken,
|
||||||
ngrok_edge,
|
ngrok_edge,
|
||||||
messages_api_enabled,
|
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
usage_stats,
|
usage_stats,
|
||||||
|
@ -184,13 +184,13 @@ async fn main() -> Result<(), RouterError> {
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
tokenizer_config_path,
|
tokenizer_config_path,
|
||||||
revision,
|
revision,
|
||||||
|
trust_remote_code,
|
||||||
hostname,
|
hostname,
|
||||||
port,
|
port,
|
||||||
cors_allow_origin,
|
cors_allow_origin,
|
||||||
ngrok,
|
ngrok,
|
||||||
ngrok_authtoken,
|
ngrok_authtoken,
|
||||||
ngrok_edge,
|
ngrok_edge,
|
||||||
messages_api_enabled,
|
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
usage_stats,
|
usage_stats,
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
|
|
||||||
/// Batching and inference logic
|
/// Batching and inference logic
|
||||||
|
use crate::client::{
|
||||||
|
Batch, CachedBatch, ClientError, Generation, Health, InfoResponse, ShardedClient,
|
||||||
|
};
|
||||||
use crate::queue::{Entry, Queue};
|
use crate::queue::{Entry, Queue};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use nohash_hasher::IntMap;
|
use nohash_hasher::IntMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||||
use text_generation_router::validation::ValidGenerateRequest;
|
use text_generation_router::validation::ValidGenerateRequest;
|
||||||
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
|
use text_generation_router::{FinishReason, PrefillToken, Token};
|
||||||
use tokio::sync::mpsc::error::SendError;
|
use tokio::sync::mpsc::error::SendError;
|
||||||
use tokio::sync::{mpsc, Notify};
|
use tokio::sync::{mpsc, Notify};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
|
@ -31,27 +33,22 @@ impl BackendV3 {
|
||||||
max_batch_total_tokens: u32,
|
max_batch_total_tokens: u32,
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
requires_padding: bool,
|
shard_info: InfoResponse,
|
||||||
window_size: Option<u32>,
|
|
||||||
speculate: u32,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let prefix_caching =
|
if shard_info.support_chunking {
|
||||||
std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
|
tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
|
||||||
let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
|
}
|
||||||
let attention: String = std::env::var("ATTENTION").expect("attention env var");
|
|
||||||
|
|
||||||
let attention: Attention = attention
|
let block_size = shard_info.block_size;
|
||||||
.parse()
|
|
||||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
|
|
||||||
let block_size = attention.block_size();
|
|
||||||
|
|
||||||
let queue = Queue::new(
|
let queue = Queue::new(
|
||||||
requires_padding,
|
shard_info.requires_padding,
|
||||||
block_size,
|
block_size,
|
||||||
prefix_caching,
|
shard_info.use_prefix_caching,
|
||||||
window_size,
|
shard_info.window_size,
|
||||||
speculate,
|
shard_info.speculate,
|
||||||
max_batch_total_tokens,
|
max_batch_total_tokens,
|
||||||
|
shard_info.support_chunking,
|
||||||
);
|
);
|
||||||
let batching_task_notifier = Arc::new(Notify::new());
|
let batching_task_notifier = Arc::new(Notify::new());
|
||||||
|
|
||||||
|
@ -63,6 +60,7 @@ impl BackendV3 {
|
||||||
max_batch_total_tokens,
|
max_batch_total_tokens,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
|
shard_info.support_chunking,
|
||||||
queue.clone(),
|
queue.clone(),
|
||||||
batching_task_notifier.clone(),
|
batching_task_notifier.clone(),
|
||||||
));
|
));
|
||||||
|
@ -127,6 +125,7 @@ pub(crate) async fn batching_task(
|
||||||
max_batch_total_tokens: u32,
|
max_batch_total_tokens: u32,
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
max_batch_size: Option<usize>,
|
max_batch_size: Option<usize>,
|
||||||
|
support_chunking: bool,
|
||||||
queue: Queue,
|
queue: Queue,
|
||||||
notifier: Arc<Notify>,
|
notifier: Arc<Notify>,
|
||||||
) {
|
) {
|
||||||
|
@ -147,7 +146,7 @@ pub(crate) async fn batching_task(
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
|
||||||
.instrument(span)
|
.instrument(span)
|
||||||
.await;
|
.await;
|
||||||
let mut waiting_tokens = 1;
|
let mut waiting_tokens = 1;
|
||||||
|
@ -158,10 +157,24 @@ pub(crate) async fn batching_task(
|
||||||
// Get current batch info
|
// Get current batch info
|
||||||
let batch_size = batch.size;
|
let batch_size = batch.size;
|
||||||
let batch_max_tokens = batch.max_tokens;
|
let batch_max_tokens = batch.max_tokens;
|
||||||
|
let current_tokens = batch.current_tokens;
|
||||||
let mut batches = vec![batch];
|
let mut batches = vec![batch];
|
||||||
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
|
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
|
||||||
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
|
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
|
||||||
|
|
||||||
|
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||||
|
|
||||||
|
let (min_size, max_size, prefill_token_budget) = if support_chunking {
|
||||||
|
// Since the next batch will be concatenated with the current batch,
|
||||||
|
// the current batch tokens must be subtracted to the prefill budget
|
||||||
|
let prefill_token_budget =
|
||||||
|
max_batch_prefill_tokens.saturating_sub(current_tokens);
|
||||||
|
// We can ignore min_size and max_size
|
||||||
|
// Models than rely on max_size cannot support chunking
|
||||||
|
// Regarding min_size, chunking allow us to consistently run at the compute
|
||||||
|
// bound, making min_size useless.
|
||||||
|
(None, None, prefill_token_budget)
|
||||||
|
} else {
|
||||||
let min_size = if waiting_tokens >= max_waiting_tokens {
|
let min_size = if waiting_tokens >= max_waiting_tokens {
|
||||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||||
// to add a new batch even though its size might be small
|
// to add a new batch even though its size might be small
|
||||||
|
@ -173,24 +186,34 @@ pub(crate) async fn batching_task(
|
||||||
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||||
};
|
};
|
||||||
|
|
||||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
|
||||||
let max_size =
|
let max_size =
|
||||||
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
|
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
|
||||||
|
|
||||||
|
(min_size, max_size, max_batch_prefill_tokens)
|
||||||
|
};
|
||||||
|
|
||||||
// Try to get a new batch
|
// Try to get a new batch
|
||||||
if let Some((mut new_entries, new_batch, span)) = queue
|
if let Some((new_entries, new_batch, span)) = queue
|
||||||
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
// Tracking metrics
|
// Tracking metrics
|
||||||
if min_size.is_some() {
|
if min_size.is_some() {
|
||||||
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
|
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
|
||||||
.increment(1);
|
.increment(1);
|
||||||
|
} else {
|
||||||
|
let counter = if support_chunking {
|
||||||
|
metrics::counter!("tgi_batch_concat", "reason" => "chunking")
|
||||||
} else {
|
} else {
|
||||||
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
|
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
|
||||||
.increment(1);
|
};
|
||||||
|
counter.increment(1);
|
||||||
}
|
}
|
||||||
|
let cached_batch = if support_chunking {
|
||||||
|
// Concat current batch to the new one
|
||||||
|
batches.pop()
|
||||||
|
} else {
|
||||||
|
// Request are waiting only if we don't support chunking
|
||||||
entries.iter_mut().for_each(|(_, entry)| {
|
entries.iter_mut().for_each(|(_, entry)| {
|
||||||
// Create a new span to add the info that this entry is waiting
|
// Create a new span to add the info that this entry is waiting
|
||||||
// because a new batch is being computed
|
// because a new batch is being computed
|
||||||
|
@ -201,17 +224,23 @@ pub(crate) async fn batching_task(
|
||||||
// Update entry
|
// Update entry
|
||||||
entry.temp_span = Some(entry_waiting_span);
|
entry.temp_span = Some(entry_waiting_span);
|
||||||
});
|
});
|
||||||
|
None
|
||||||
|
};
|
||||||
|
entries.extend(new_entries);
|
||||||
|
|
||||||
// Generate one token for this new batch to have the attention past in cache
|
// Generate one token for this new batch to have the attention past in cache
|
||||||
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
|
let new_cached_batch =
|
||||||
|
prefill(&mut client, new_batch, cached_batch, &mut entries)
|
||||||
.instrument(span)
|
.instrument(span)
|
||||||
.await;
|
.await;
|
||||||
// Reset waiting counter
|
// Reset waiting counter
|
||||||
waiting_tokens = 1;
|
waiting_tokens = 1;
|
||||||
// Extend current batch with the new batch
|
// Extend current batch with the new batch
|
||||||
if let Some(new_cached_batch) = new_cached_batch {
|
if let Some(new_cached_batch) = new_cached_batch {
|
||||||
entries.extend(new_entries);
|
|
||||||
batches.push(new_cached_batch);
|
batches.push(new_cached_batch);
|
||||||
|
} else if support_chunking {
|
||||||
|
// New cached batch is empty, no work left
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -244,13 +273,14 @@ pub(crate) async fn batching_task(
|
||||||
async fn prefill(
|
async fn prefill(
|
||||||
client: &mut ShardedClient,
|
client: &mut ShardedClient,
|
||||||
batch: Batch,
|
batch: Batch,
|
||||||
|
cached_batch: Option<CachedBatch>,
|
||||||
entries: &mut IntMap<u64, Entry>,
|
entries: &mut IntMap<u64, Entry>,
|
||||||
) -> Option<CachedBatch> {
|
) -> Option<CachedBatch> {
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
let batch_id = batch.id;
|
let batch_id = batch.id;
|
||||||
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
|
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
|
||||||
|
|
||||||
match client.prefill(batch).await {
|
match client.prefill(batch, cached_batch).await {
|
||||||
Ok((generations, next_batch, timings)) => {
|
Ok((generations, next_batch, timings)) => {
|
||||||
let start_filtering_time = Instant::now();
|
let start_filtering_time = Instant::now();
|
||||||
// Send generated tokens and filter stopped entries
|
// Send generated tokens and filter stopped entries
|
||||||
|
@ -259,6 +289,10 @@ async fn prefill(
|
||||||
// Filter next batch and remove requests that were stopped
|
// Filter next batch and remove requests that were stopped
|
||||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||||
|
|
||||||
|
if let Some(concat_duration) = timings.concat {
|
||||||
|
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
|
||||||
|
.record(concat_duration.as_secs_f64());
|
||||||
|
}
|
||||||
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
|
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
|
||||||
.record(timings.forward.as_secs_f64());
|
.record(timings.forward.as_secs_f64());
|
||||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
||||||
|
|
|
@ -158,7 +158,8 @@ impl Client {
|
||||||
// Blocks and slots will be set on the server side if we use paged attention
|
// Blocks and slots will be set on the server side if we use paged attention
|
||||||
blocks: vec![],
|
blocks: vec![],
|
||||||
slots: vec![],
|
slots: vec![],
|
||||||
prefix_len: 0,
|
cache_len: 0,
|
||||||
|
chunk_len: None,
|
||||||
// Set sampling parameters to also take these ops into account in the max memory
|
// Set sampling parameters to also take these ops into account in the max memory
|
||||||
parameters: Some(NextTokenChooserParameters {
|
parameters: Some(NextTokenChooserParameters {
|
||||||
temperature: 0.9,
|
temperature: 0.9,
|
||||||
|
@ -217,13 +218,23 @@ impl Client {
|
||||||
pub async fn prefill(
|
pub async fn prefill(
|
||||||
&mut self,
|
&mut self,
|
||||||
batch: Batch,
|
batch: Batch,
|
||||||
|
cached_batch: Option<CachedBatch>,
|
||||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||||
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
|
let request = tonic::Request::new(PrefillRequest {
|
||||||
|
batch: Some(batch),
|
||||||
|
cached_batch,
|
||||||
|
})
|
||||||
|
.inject_context();
|
||||||
let response = self.stub.prefill(request).await?.into_inner();
|
let response = self.stub.prefill(request).await?.into_inner();
|
||||||
Ok((
|
Ok((
|
||||||
response.generations,
|
response.generations,
|
||||||
response.batch,
|
response.batch,
|
||||||
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
|
PrefillTimings::new(
|
||||||
|
response.concat_ns,
|
||||||
|
response.forward_ns,
|
||||||
|
response.decode_ns,
|
||||||
|
response.total_ns,
|
||||||
|
),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -252,14 +263,16 @@ impl Client {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct PrefillTimings {
|
pub struct PrefillTimings {
|
||||||
|
pub concat: Option<Duration>,
|
||||||
pub forward: Duration,
|
pub forward: Duration,
|
||||||
pub decode: Duration,
|
pub decode: Duration,
|
||||||
pub total: Duration,
|
pub total: Duration,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PrefillTimings {
|
impl PrefillTimings {
|
||||||
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
concat: concat_ns.map(Duration::from_nanos),
|
||||||
forward: Duration::from_nanos(forward_ns),
|
forward: Duration::from_nanos(forward_ns),
|
||||||
decode: Duration::from_nanos(decode_ns),
|
decode: Duration::from_nanos(decode_ns),
|
||||||
total: Duration::from_nanos(total_ns),
|
total: Duration::from_nanos(total_ns),
|
||||||
|
|
|
@ -29,15 +29,6 @@ pub trait Health {
|
||||||
async fn model_health(&self) -> Result<()>;
|
async fn model_health(&self) -> Result<()>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct ShardInfo {
|
|
||||||
pub requires_padding: bool,
|
|
||||||
pub dtype: String,
|
|
||||||
pub device_type: String,
|
|
||||||
pub window_size: Option<u32>,
|
|
||||||
pub speculate: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Error, Debug, Clone)]
|
#[derive(Error, Debug, Clone)]
|
||||||
pub enum ClientError {
|
pub enum ClientError {
|
||||||
#[error("Could not connect to Text Generation server: {0}")]
|
#[error("Could not connect to Text Generation server: {0}")]
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use crate::client::{ClientError, Result};
|
use crate::client::Health;
|
||||||
/// Multi shard Client
|
/// Multi shard Client
|
||||||
use crate::client::{Health, ShardInfo};
|
use crate::client::{ClientError, Result};
|
||||||
|
|
||||||
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
|
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
|
||||||
use crate::client::{
|
use crate::client::{
|
||||||
|
@ -49,13 +49,13 @@ impl ShardedClient {
|
||||||
|
|
||||||
/// Get the model info
|
/// Get the model info
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn info(&mut self) -> Result<ShardInfo> {
|
pub async fn info(&mut self) -> Result<InfoResponse> {
|
||||||
let futures: Vec<_> = self
|
let futures: Vec<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| client.info())
|
.map(|client| client.info())
|
||||||
.collect();
|
.collect();
|
||||||
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
|
join_all(futures).await.pop().unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GRPC health check
|
/// GRPC health check
|
||||||
|
@ -135,11 +135,12 @@ impl ShardedClient {
|
||||||
pub async fn prefill(
|
pub async fn prefill(
|
||||||
&mut self,
|
&mut self,
|
||||||
batch: Batch,
|
batch: Batch,
|
||||||
|
cached_batch: Option<CachedBatch>,
|
||||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||||
let futures: Vec<_> = self
|
let futures: Vec<_> = self
|
||||||
.clients
|
.clients
|
||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| Box::pin(client.prefill(batch.clone())))
|
.map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
|
||||||
.collect();
|
.collect();
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
||||||
|
@ -194,18 +195,6 @@ impl ShardedClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<InfoResponse> for ShardInfo {
|
|
||||||
fn from(value: InfoResponse) -> Self {
|
|
||||||
Self {
|
|
||||||
requires_padding: value.requires_padding,
|
|
||||||
dtype: value.dtype,
|
|
||||||
device_type: value.device_type,
|
|
||||||
window_size: value.window_size,
|
|
||||||
speculate: value.speculate,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl Health for ShardedClient {
|
impl Health for ShardedClient {
|
||||||
async fn device_health(&self) -> Result<()> {
|
async fn device_health(&self) -> Result<()> {
|
||||||
|
@ -246,8 +235,9 @@ impl Health for ShardedClient {
|
||||||
// Block 0 is reserved for health checks
|
// Block 0 is reserved for health checks
|
||||||
blocks: vec![0],
|
blocks: vec![0],
|
||||||
slots: (0..16).collect(),
|
slots: (0..16).collect(),
|
||||||
prefix_len: 0,
|
cache_len: 0,
|
||||||
adapter_id: None,
|
adapter_id: None,
|
||||||
|
chunk_len: None,
|
||||||
};
|
};
|
||||||
let batch = Batch {
|
let batch = Batch {
|
||||||
id: u64::MAX,
|
id: u64::MAX,
|
||||||
|
@ -256,7 +246,7 @@ impl Health for ShardedClient {
|
||||||
max_tokens: 2,
|
max_tokens: 2,
|
||||||
max_blocks: 1,
|
max_blocks: 1,
|
||||||
};
|
};
|
||||||
self.clone().prefill(batch).await?;
|
self.clone().prefill(batch, None).await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,14 @@ pub struct BackendInfo {
|
||||||
pub max_waiting_tokens: usize,
|
pub max_waiting_tokens: usize,
|
||||||
#[schema(nullable = true, example = "null")]
|
#[schema(nullable = true, example = "null")]
|
||||||
pub max_batch_size: Option<usize>,
|
pub max_batch_size: Option<usize>,
|
||||||
|
#[schema(example = "false")]
|
||||||
|
pub support_chunking: bool,
|
||||||
|
#[schema(example = "false")]
|
||||||
|
pub prefix_caching: bool,
|
||||||
|
#[schema(example = "flashinfer")]
|
||||||
|
pub attention_impl: String,
|
||||||
|
#[schema(example = "1")]
|
||||||
|
pub block_size: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
@ -110,6 +118,10 @@ pub async fn connect_backend(
|
||||||
model_device_type: shard_info.device_type.clone(),
|
model_device_type: shard_info.device_type.clone(),
|
||||||
model_dtype: shard_info.dtype.clone(),
|
model_dtype: shard_info.dtype.clone(),
|
||||||
speculate: shard_info.speculate as usize,
|
speculate: shard_info.speculate as usize,
|
||||||
|
support_chunking: shard_info.support_chunking,
|
||||||
|
prefix_caching: shard_info.use_prefix_caching,
|
||||||
|
attention_impl: shard_info.attention_impl.clone(),
|
||||||
|
block_size: shard_info.block_size,
|
||||||
};
|
};
|
||||||
|
|
||||||
let backend = BackendV3::new(
|
let backend = BackendV3::new(
|
||||||
|
@ -119,9 +131,7 @@ pub async fn connect_backend(
|
||||||
max_batch_total_tokens,
|
max_batch_total_tokens,
|
||||||
max_waiting_tokens,
|
max_waiting_tokens,
|
||||||
max_batch_size,
|
max_batch_size,
|
||||||
shard_info.requires_padding,
|
shard_info,
|
||||||
shard_info.window_size,
|
|
||||||
shard_info.speculate,
|
|
||||||
);
|
);
|
||||||
|
|
||||||
tracing::info!("Using backend V3");
|
tracing::info!("Using backend V3");
|
||||||
|
|
|
@ -44,6 +44,8 @@ struct Args {
|
||||||
tokenizer_config_path: Option<String>,
|
tokenizer_config_path: Option<String>,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
|
#[clap(long, env, value_enum)]
|
||||||
|
trust_remote_code: bool,
|
||||||
#[clap(default_value = "2", long, env)]
|
#[clap(default_value = "2", long, env)]
|
||||||
validation_workers: usize,
|
validation_workers: usize,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
|
@ -63,8 +65,6 @@ struct Args {
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
ngrok_edge: Option<String>,
|
ngrok_edge: Option<String>,
|
||||||
#[clap(long, env, default_value_t = false)]
|
#[clap(long, env, default_value_t = false)]
|
||||||
messages_api_enabled: bool,
|
|
||||||
#[clap(long, env, default_value_t = false)]
|
|
||||||
disable_grammar_support: bool,
|
disable_grammar_support: bool,
|
||||||
#[clap(default_value = "4", long, env)]
|
#[clap(default_value = "4", long, env)]
|
||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
|
@ -101,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
tokenizer_config_path,
|
tokenizer_config_path,
|
||||||
revision,
|
revision,
|
||||||
|
trust_remote_code,
|
||||||
validation_workers,
|
validation_workers,
|
||||||
api_key,
|
api_key,
|
||||||
json_output,
|
json_output,
|
||||||
|
@ -110,7 +111,6 @@ async fn main() -> Result<(), RouterError> {
|
||||||
ngrok,
|
ngrok,
|
||||||
ngrok_authtoken,
|
ngrok_authtoken,
|
||||||
ngrok_edge,
|
ngrok_edge,
|
||||||
messages_api_enabled,
|
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
usage_stats,
|
usage_stats,
|
||||||
|
@ -131,25 +131,12 @@ async fn main() -> Result<(), RouterError> {
|
||||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
|
||||||
}
|
|
||||||
|
|
||||||
if validation_workers == 0 {
|
if validation_workers == 0 {
|
||||||
return Err(RouterError::ArgumentValidation(
|
return Err(RouterError::ArgumentValidation(
|
||||||
"`validation_workers` must be > 0".to_string(),
|
"`validation_workers` must be > 0".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
|
||||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
|
||||||
}
|
|
||||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(max_batch_size) = max_batch_size {
|
if let Some(max_batch_size) = max_batch_size {
|
||||||
if max_batch_size == 0 {
|
if max_batch_size == 0 {
|
||||||
return Err(RouterError::ArgumentValidation(
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
@ -158,7 +145,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let (backend, _backend_info) = connect_backend(
|
let (backend, backend_info) = connect_backend(
|
||||||
max_input_tokens,
|
max_input_tokens,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
master_shard_uds_path,
|
master_shard_uds_path,
|
||||||
|
@ -170,6 +157,19 @@ async fn main() -> Result<(), RouterError> {
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
// Validate remaining args now that the backend is known
|
||||||
|
let support_chunking = backend_info.support_chunking;
|
||||||
|
let max_batch_total_tokens = backend_info.max_batch_total_tokens;
|
||||||
|
if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||||
|
}
|
||||||
|
if max_batch_prefill_tokens > max_batch_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||||
|
}
|
||||||
|
if max_total_tokens as u32 > max_batch_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||||
|
}
|
||||||
|
|
||||||
// Run server
|
// Run server
|
||||||
server::run(
|
server::run(
|
||||||
backend,
|
backend,
|
||||||
|
@ -184,13 +184,13 @@ async fn main() -> Result<(), RouterError> {
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
tokenizer_config_path,
|
tokenizer_config_path,
|
||||||
revision,
|
revision,
|
||||||
|
trust_remote_code,
|
||||||
hostname,
|
hostname,
|
||||||
port,
|
port,
|
||||||
cors_allow_origin,
|
cors_allow_origin,
|
||||||
ngrok,
|
ngrok,
|
||||||
ngrok_authtoken,
|
ngrok_authtoken,
|
||||||
ngrok_edge,
|
ngrok_edge,
|
||||||
messages_api_enabled,
|
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
usage_stats,
|
usage_stats,
|
||||||
|
|
|
@ -4,7 +4,7 @@ use crate::client::{
|
||||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
};
|
};
|
||||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||||
use std::cmp::{max, min};
|
use std::cmp::max;
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use text_generation_router::infer::InferError;
|
use text_generation_router::infer::InferError;
|
||||||
use text_generation_router::infer::InferStreamResponse;
|
use text_generation_router::infer::InferStreamResponse;
|
||||||
|
@ -50,6 +50,7 @@ impl Queue {
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
max_batch_total_tokens: u32,
|
max_batch_total_tokens: u32,
|
||||||
|
support_chunking: bool,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Create channel
|
// Create channel
|
||||||
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
|
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
|
||||||
|
@ -62,6 +63,7 @@ impl Queue {
|
||||||
window_size,
|
window_size,
|
||||||
speculate,
|
speculate,
|
||||||
max_batch_total_tokens,
|
max_batch_total_tokens,
|
||||||
|
support_chunking,
|
||||||
queue_receiver,
|
queue_receiver,
|
||||||
));
|
));
|
||||||
|
|
||||||
|
@ -87,6 +89,10 @@ impl Queue {
|
||||||
prefill_token_budget: u32,
|
prefill_token_budget: u32,
|
||||||
token_budget: u32,
|
token_budget: u32,
|
||||||
) -> Option<NextBatch> {
|
) -> Option<NextBatch> {
|
||||||
|
if prefill_token_budget == 0 || token_budget == 0 {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
// Create response channel
|
// Create response channel
|
||||||
let (response_sender, response_receiver) = oneshot::channel();
|
let (response_sender, response_receiver) = oneshot::channel();
|
||||||
// Send next batch command to the background task managing the state
|
// Send next batch command to the background task managing the state
|
||||||
|
@ -108,6 +114,7 @@ impl Queue {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Background task responsible of the queue state
|
// Background task responsible of the queue state
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
async fn queue_task(
|
async fn queue_task(
|
||||||
requires_padding: bool,
|
requires_padding: bool,
|
||||||
block_size: u32,
|
block_size: u32,
|
||||||
|
@ -115,6 +122,7 @@ async fn queue_task(
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
max_batch_total_tokens: u32,
|
max_batch_total_tokens: u32,
|
||||||
|
support_chunking: bool,
|
||||||
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
|
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
|
||||||
) {
|
) {
|
||||||
let mut state = State::new(
|
let mut state = State::new(
|
||||||
|
@ -124,6 +132,7 @@ async fn queue_task(
|
||||||
window_size,
|
window_size,
|
||||||
speculate,
|
speculate,
|
||||||
max_batch_total_tokens,
|
max_batch_total_tokens,
|
||||||
|
support_chunking,
|
||||||
);
|
);
|
||||||
|
|
||||||
while let Some(cmd) = receiver.recv().await {
|
while let Some(cmd) = receiver.recv().await {
|
||||||
|
@ -166,12 +175,14 @@ struct State {
|
||||||
/// Paged Attention block size
|
/// Paged Attention block size
|
||||||
block_size: u32,
|
block_size: u32,
|
||||||
|
|
||||||
/// Sliding window
|
|
||||||
window_size: Option<u32>,
|
|
||||||
|
|
||||||
/// Speculation amount
|
/// Speculation amount
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
|
|
||||||
|
/// Whether the model allow the prefill chunking
|
||||||
|
/// If it does, the last request in the batch will be split to exactly match the prefill
|
||||||
|
/// token budget
|
||||||
|
support_chunking: bool,
|
||||||
|
|
||||||
/// Paged Attention Block Allocation
|
/// Paged Attention Block Allocation
|
||||||
block_allocator: Option<BlockAllocator>,
|
block_allocator: Option<BlockAllocator>,
|
||||||
}
|
}
|
||||||
|
@ -184,6 +195,7 @@ impl State {
|
||||||
window_size: Option<u32>,
|
window_size: Option<u32>,
|
||||||
speculate: u32,
|
speculate: u32,
|
||||||
max_batch_total_tokens: u32,
|
max_batch_total_tokens: u32,
|
||||||
|
support_chunking: bool,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let block_allocator = (!requires_padding).then(|| {
|
let block_allocator = (!requires_padding).then(|| {
|
||||||
BlockAllocator::new(
|
BlockAllocator::new(
|
||||||
|
@ -199,8 +211,8 @@ impl State {
|
||||||
next_id: 0,
|
next_id: 0,
|
||||||
next_batch_id: 0,
|
next_batch_id: 0,
|
||||||
block_size,
|
block_size,
|
||||||
window_size,
|
|
||||||
speculate,
|
speculate,
|
||||||
|
support_chunking,
|
||||||
block_allocator,
|
block_allocator,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -287,32 +299,7 @@ impl State {
|
||||||
}
|
}
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
Some(_block_allocator) => {
|
Some(block_allocator) => {
|
||||||
prefill_tokens += entry.request.input_length;
|
|
||||||
let max_new_tokens = match self.window_size {
|
|
||||||
None => entry.request.stopping_parameters.max_new_tokens,
|
|
||||||
Some(window_size) => min(
|
|
||||||
window_size.saturating_sub(entry.request.input_length),
|
|
||||||
entry.request.stopping_parameters.max_new_tokens,
|
|
||||||
),
|
|
||||||
};
|
|
||||||
decode_tokens += max_new_tokens;
|
|
||||||
|
|
||||||
if prefill_tokens > prefill_token_budget
|
|
||||||
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
|
|
||||||
{
|
|
||||||
// Entry is over budget
|
|
||||||
// Add it back to the front
|
|
||||||
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
|
||||||
self.entries.push_front((id, entry));
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let tokens = entry.request.input_length
|
|
||||||
+ entry.request.stopping_parameters.max_new_tokens
|
|
||||||
+ self.speculate
|
|
||||||
- 1;
|
|
||||||
|
|
||||||
// If users wants the prefill logprobs, we cannot reuse the cache.
|
// If users wants the prefill logprobs, we cannot reuse the cache.
|
||||||
// So no input_ids for the radix tree.
|
// So no input_ids for the radix tree.
|
||||||
let input_ids = if entry.request.decoder_input_details {
|
let input_ids = if entry.request.decoder_input_details {
|
||||||
|
@ -321,10 +308,73 @@ impl State {
|
||||||
entry.request.input_ids.clone()
|
entry.request.input_ids.clone()
|
||||||
};
|
};
|
||||||
|
|
||||||
Some((tokens, input_ids))
|
let tokens = entry.request.input_length
|
||||||
|
+ entry.request.stopping_parameters.max_new_tokens
|
||||||
|
+ self.speculate
|
||||||
|
- 1;
|
||||||
|
tracing::debug!("Allocating {tokens} with {input_ids:?}");
|
||||||
|
|
||||||
|
let block_allocation = match block_allocator.allocate(tokens, input_ids).await {
|
||||||
|
None => {
|
||||||
|
// Entry is over budget
|
||||||
|
// Add it back to the front
|
||||||
|
tracing::debug!("Over budget: not enough free blocks");
|
||||||
|
self.entries.push_front((id, entry));
|
||||||
|
break 'entry_loop;
|
||||||
|
}
|
||||||
|
Some(mut block_allocation) => {
|
||||||
|
tracing::debug!("Allocation: {block_allocation:?}");
|
||||||
|
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
||||||
|
|
||||||
|
if block_allocation.prefix_len == entry.request.input_length {
|
||||||
|
// The whole request was found in the radix trie
|
||||||
|
// However, for the transformer forward to work, we need to
|
||||||
|
// have at least one token of postfix.
|
||||||
|
block_allocation.prefix_len -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
block_allocation
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
batch.push((id, entry, block_allocation));
|
|
||||||
|
let postfix_len = entry.request.input_length - block_allocation.prefix_len;
|
||||||
|
|
||||||
|
if prefill_tokens + postfix_len > prefill_token_budget {
|
||||||
|
// Entry is over budget
|
||||||
|
if self.support_chunking {
|
||||||
|
// We support chunking, just set postfix_len to exactly match prefill_token_budget
|
||||||
|
let chunk_len = prefill_token_budget.saturating_sub(prefill_tokens);
|
||||||
|
if chunk_len > 0 {
|
||||||
|
// Push this entry inside the batch
|
||||||
|
batch.push((id, entry, Some(block_allocation), Some(chunk_len)));
|
||||||
|
} else {
|
||||||
|
// We cannot prefill even one token for this entry
|
||||||
|
// Add it back to the queue
|
||||||
|
self.entries.push_front((id, entry));
|
||||||
|
}
|
||||||
|
tracing::debug!(
|
||||||
|
"Matched budget: prefill_tokens={} == {prefill_token_budget}",
|
||||||
|
prefill_tokens + postfix_len
|
||||||
|
);
|
||||||
|
break 'entry_loop;
|
||||||
|
} else {
|
||||||
|
// We don't support chunking, this entry needs to go back to the buffer
|
||||||
|
// Add it back to the front
|
||||||
|
tracing::debug!(
|
||||||
|
"Over budget: prefill_tokens={} > {prefill_token_budget}",
|
||||||
|
prefill_tokens + postfix_len
|
||||||
|
);
|
||||||
|
self.entries.push_front((id, entry));
|
||||||
|
break 'entry_loop;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
prefill_tokens += postfix_len;
|
||||||
|
|
||||||
|
Some(block_allocation)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
batch.push((id, entry, block_allocation, None));
|
||||||
if Some(batch.len()) == max_size {
|
if Some(batch.len()) == max_size {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -342,7 +392,7 @@ impl State {
|
||||||
// Batch is too small
|
// Batch is too small
|
||||||
if batch.len() < min_size {
|
if batch.len() < min_size {
|
||||||
// Add back entries to the queue in the correct order
|
// Add back entries to the queue in the correct order
|
||||||
for (id, entry, _) in batch.into_iter().rev() {
|
for (id, entry, _, _) in batch.into_iter().rev() {
|
||||||
self.entries.push_front((id, entry));
|
self.entries.push_front((id, entry));
|
||||||
}
|
}
|
||||||
return None;
|
return None;
|
||||||
|
@ -353,29 +403,7 @@ impl State {
|
||||||
let mut batch_entries =
|
let mut batch_entries =
|
||||||
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
|
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
|
||||||
|
|
||||||
for (id, mut entry, block_allocation) in batch {
|
for (id, mut entry, block_allocation, chunk_len) in batch {
|
||||||
let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
|
|
||||||
(block_allocation, &self.block_allocator)
|
|
||||||
{
|
|
||||||
tracing::debug!("Allocating {tokens} with {input_ids:?}");
|
|
||||||
match block_allocator.allocate(tokens, input_ids).await {
|
|
||||||
None => {
|
|
||||||
// Entry is over budget
|
|
||||||
// Add it back to the front
|
|
||||||
tracing::debug!("Over budget: not enough free blocks");
|
|
||||||
self.entries.push_front((id, entry));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
Some(block_allocation) => {
|
|
||||||
tracing::debug!("Allocation: {block_allocation:?}");
|
|
||||||
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
|
||||||
Some(block_allocation)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
tracing::debug!("Accepting entry");
|
|
||||||
// Create a new span to link the batch back to this entry
|
// Create a new span to link the batch back to this entry
|
||||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||||
// Add relationships
|
// Add relationships
|
||||||
|
@ -427,8 +455,9 @@ impl State {
|
||||||
top_n_tokens: entry.request.top_n_tokens,
|
top_n_tokens: entry.request.top_n_tokens,
|
||||||
blocks,
|
blocks,
|
||||||
slots,
|
slots,
|
||||||
prefix_len,
|
cache_len: prefix_len,
|
||||||
adapter_id: entry.request.adapter_id.clone(),
|
adapter_id: entry.request.adapter_id.clone(),
|
||||||
|
chunk_len,
|
||||||
});
|
});
|
||||||
// Set batch_time
|
// Set batch_time
|
||||||
entry.batch_time = Some(Instant::now());
|
entry.batch_time = Some(Instant::now());
|
||||||
|
@ -436,12 +465,6 @@ impl State {
|
||||||
batch_entries.insert(id, entry);
|
batch_entries.insert(id, entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Empty batch
|
|
||||||
if batch_requests.is_empty() {
|
|
||||||
tracing::debug!("Filterered out all entries");
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Final batch size
|
// Final batch size
|
||||||
let size = batch_requests.len() as u32;
|
let size = batch_requests.len() as u32;
|
||||||
next_batch_span.record("batch_size", size);
|
next_batch_span.record("batch_size", size);
|
||||||
|
@ -531,7 +554,7 @@ mod tests {
|
||||||
request: ValidGenerateRequest {
|
request: ValidGenerateRequest {
|
||||||
inputs: vec![],
|
inputs: vec![],
|
||||||
input_ids: Some(Arc::new(vec![])),
|
input_ids: Some(Arc::new(vec![])),
|
||||||
input_length: 0,
|
input_length: 1,
|
||||||
add_special_tokens: true,
|
add_special_tokens: true,
|
||||||
truncate: 0,
|
truncate: 0,
|
||||||
decoder_input_details: false,
|
decoder_input_details: false,
|
||||||
|
@ -567,7 +590,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_append() {
|
async fn test_append() {
|
||||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
let mut state = State::new(false, 1, false, None, 0, 16, false);
|
||||||
let (entry, _guard) = default_entry();
|
let (entry, _guard) = default_entry();
|
||||||
|
|
||||||
assert_eq!(state.next_id, 0);
|
assert_eq!(state.next_id, 0);
|
||||||
|
@ -583,7 +606,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_next_batch_empty() {
|
async fn test_next_batch_empty() {
|
||||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
let mut state = State::new(false, 1, false, None, 0, 16, false);
|
||||||
|
|
||||||
assert!(state.next_batch(None, None, 1, 1).await.is_none());
|
assert!(state.next_batch(None, None, 1, 1).await.is_none());
|
||||||
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
|
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||||
|
@ -591,7 +614,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_next_batch_min_size() {
|
async fn test_next_batch_min_size() {
|
||||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
let mut state = State::new(false, 1, false, None, 0, 16, false);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
|
@ -623,7 +646,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_next_batch_max_size() {
|
async fn test_next_batch_max_size() {
|
||||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
let mut state = State::new(false, 1, false, None, 0, 16, false);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
|
@ -643,7 +666,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_next_batch_token_budget() {
|
async fn test_next_batch_token_budget() {
|
||||||
let mut state = State::new(false, 1, false, None, 0, 2);
|
let mut state = State::new(false, 1, false, None, 0, 16, false);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
state.append(entry1);
|
state.append(entry1);
|
||||||
|
@ -676,14 +699,14 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_append() {
|
async fn test_queue_append() {
|
||||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
let queue = Queue::new(false, 1, false, None, 0, 16, false);
|
||||||
let (entry, _guard) = default_entry();
|
let (entry, _guard) = default_entry();
|
||||||
queue.append(entry);
|
queue.append(entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_empty() {
|
async fn test_queue_next_batch_empty() {
|
||||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
let queue = Queue::new(false, 1, false, None, 0, 16, false);
|
||||||
|
|
||||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||||
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||||
|
@ -691,7 +714,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_min_size() {
|
async fn test_queue_next_batch_min_size() {
|
||||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
let queue = Queue::new(false, 1, false, None, 0, 16, false);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
|
@ -724,7 +747,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_max_size() {
|
async fn test_queue_next_batch_max_size() {
|
||||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
let queue = Queue::new(false, 1, false, None, 0, 16, false);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
|
@ -740,7 +763,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_token_budget() {
|
async fn test_queue_next_batch_token_budget() {
|
||||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
let queue = Queue::new(false, 1, false, None, 0, 16, false);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
|
@ -765,7 +788,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_token_speculate() {
|
async fn test_queue_next_batch_token_speculate() {
|
||||||
let queue = Queue::new(false, 1, false, None, 2, 16);
|
let queue = Queue::new(true, 1, false, None, 2, 16, false);
|
||||||
let (entry1, _guard1) = default_entry();
|
let (entry1, _guard1) = default_entry();
|
||||||
let (entry2, _guard2) = default_entry();
|
let (entry2, _guard2) = default_entry();
|
||||||
queue.append(entry1);
|
queue.append(entry1);
|
||||||
|
@ -784,7 +807,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_queue_next_batch_dropped_receiver() {
|
async fn test_queue_next_batch_dropped_receiver() {
|
||||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
let queue = Queue::new(false, 1, false, None, 0, 16, false);
|
||||||
let (entry, _) = default_entry();
|
let (entry, _) = default_entry();
|
||||||
queue.append(entry);
|
queue.append(entry);
|
||||||
|
|
||||||
|
|
|
@ -158,7 +158,8 @@ async fn prefill(
|
||||||
top_n_tokens: top_n_tokens.unwrap_or(0),
|
top_n_tokens: top_n_tokens.unwrap_or(0),
|
||||||
blocks: vec![],
|
blocks: vec![],
|
||||||
slots: vec![],
|
slots: vec![],
|
||||||
prefix_len: 0,
|
cache_len: 0,
|
||||||
|
chunk_len: None,
|
||||||
adapter_id: None,
|
adapter_id: None,
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
@ -173,7 +174,7 @@ async fn prefill(
|
||||||
|
|
||||||
// Run prefill
|
// Run prefill
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
|
let (_, decode_batch, _) = client.prefill(batch.clone(), None).await?;
|
||||||
|
|
||||||
// Get latency
|
// Get latency
|
||||||
let latency = start_time.elapsed();
|
let latency = start_time.elapsed();
|
||||||
|
|
|
@ -178,6 +178,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
.clear_cache(None)
|
.clear_cache(None)
|
||||||
.await
|
.await
|
||||||
.expect("Unable to clear cache");
|
.expect("Unable to clear cache");
|
||||||
|
|
||||||
tracing::info!("Connected");
|
tracing::info!("Connected");
|
||||||
|
|
||||||
// Run app
|
// Run app
|
||||||
|
|
|
@ -316,6 +316,98 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/invocations": {
|
||||||
|
"post": {
|
||||||
|
"tags": [
|
||||||
|
"Text Generation Inference"
|
||||||
|
],
|
||||||
|
"summary": "Generate tokens from Sagemaker request",
|
||||||
|
"operationId": "sagemaker_compatibility",
|
||||||
|
"requestBody": {
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/SagemakerRequest"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
},
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "Generated Chat Completion",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/SagemakerResponse"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"text/event-stream": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/SagemakerStreamResponse"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"422": {
|
||||||
|
"description": "Input validation error",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/ErrorResponse"
|
||||||
|
},
|
||||||
|
"example": {
|
||||||
|
"error": "Input validation error",
|
||||||
|
"error_type": "validation"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"424": {
|
||||||
|
"description": "Generation Error",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/ErrorResponse"
|
||||||
|
},
|
||||||
|
"example": {
|
||||||
|
"error": "Request failed during generation",
|
||||||
|
"error_type": "generation"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"429": {
|
||||||
|
"description": "Model is overloaded",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/ErrorResponse"
|
||||||
|
},
|
||||||
|
"example": {
|
||||||
|
"error": "Model is overloaded",
|
||||||
|
"error_type": "overloaded"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"500": {
|
||||||
|
"description": "Incomplete generation",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/ErrorResponse"
|
||||||
|
},
|
||||||
|
"example": {
|
||||||
|
"error": "Incomplete generation",
|
||||||
|
"error_type": "incomplete_generation"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"/metrics": {
|
"/metrics": {
|
||||||
"get": {
|
"get": {
|
||||||
"tags": [
|
"tags": [
|
||||||
|
@ -1865,6 +1957,45 @@
|
||||||
"type": "string"
|
"type": "string"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"SagemakerRequest": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/CompatGenerateRequest"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/ChatRequest"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/CompletionRequest"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"SagemakerResponse": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/GenerateResponse"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/ChatCompletion"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/CompletionFinal"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"SagemakerStreamResponse": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/StreamResponse"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/ChatCompletionChunk"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/Chunk"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
"SimpleToken": {
|
"SimpleToken": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
|
|
|
@ -141,9 +141,7 @@ TGI can be deployed on various cloud providers for scalable and robust text gene
|
||||||
|
|
||||||
## Amazon SageMaker
|
## Amazon SageMaker
|
||||||
|
|
||||||
To enable the Messages API in Amazon SageMaker you need to set the environment variable `MESSAGES_API_ENABLED=true`.
|
Amazon Sagemaker natively supports the message API:
|
||||||
|
|
||||||
This will modify the `/invocations` route to accept Messages dictonaries consisting out of role and content. See the example below on how to deploy Llama with the new Messages API.
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import json
|
import json
|
||||||
|
@ -161,12 +159,11 @@ except ValueError:
|
||||||
hub = {
|
hub = {
|
||||||
'HF_MODEL_ID':'HuggingFaceH4/zephyr-7b-beta',
|
'HF_MODEL_ID':'HuggingFaceH4/zephyr-7b-beta',
|
||||||
'SM_NUM_GPUS': json.dumps(1),
|
'SM_NUM_GPUS': json.dumps(1),
|
||||||
'MESSAGES_API_ENABLED': True
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# create Hugging Face Model Class
|
# create Hugging Face Model Class
|
||||||
huggingface_model = HuggingFaceModel(
|
huggingface_model = HuggingFaceModel(
|
||||||
image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.0"),
|
image_uri=get_huggingface_llm_image_uri("huggingface",version="2.3.2"),
|
||||||
env=hub,
|
env=hub,
|
||||||
role=role,
|
role=role,
|
||||||
)
|
)
|
||||||
|
|
|
@ -93,10 +93,10 @@ Options:
|
||||||
## KV_CACHE_DTYPE
|
## KV_CACHE_DTYPE
|
||||||
```shell
|
```shell
|
||||||
--kv-cache-dtype <KV_CACHE_DTYPE>
|
--kv-cache-dtype <KV_CACHE_DTYPE>
|
||||||
Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA
|
Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value are `fp8_e4m3fn` and `fp8_e5m2` on CUDA
|
||||||
|
|
||||||
[env: KV_CACHE_DTYPE=]
|
[env: KV_CACHE_DTYPE=]
|
||||||
[possible values: fp8_e5m2]
|
[possible values: fp8_e4m3fn, fp8_e5m2]
|
||||||
|
|
||||||
```
|
```
|
||||||
## TRUST_REMOTE_CODE
|
## TRUST_REMOTE_CODE
|
||||||
|
|
|
@ -8,6 +8,7 @@ Text Generation Inference enables serving optimized models. The following sectio
|
||||||
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
|
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
|
||||||
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
|
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
|
||||||
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
|
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
|
||||||
|
- [Granite](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct)
|
||||||
- [Gemma](https://huggingface.co/google/gemma-7b)
|
- [Gemma](https://huggingface.co/google/gemma-7b)
|
||||||
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
|
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
|
||||||
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
|
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
|
||||||
|
|
|
@ -26,7 +26,6 @@ As of release 2.1.2 this is an example of the data collected:
|
||||||
"max_top_n_tokens": 5,
|
"max_top_n_tokens": 5,
|
||||||
"max_total_tokens": 2048,
|
"max_total_tokens": 2048,
|
||||||
"max_waiting_tokens": 20,
|
"max_waiting_tokens": 20,
|
||||||
"messages_api_enabled": false,
|
|
||||||
"model_config": {
|
"model_config": {
|
||||||
"model_type": "Bloom"
|
"model_type": "Bloom"
|
||||||
},
|
},
|
||||||
|
|
|
@ -978,15 +978,16 @@
|
||||||
"nixpkgs": "nixpkgs_6"
|
"nixpkgs": "nixpkgs_6"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1728381423,
|
"lastModified": 1729531056,
|
||||||
"narHash": "sha256-gpHy1WtlA8ZTd8XmxsdCoDd4Z7DE7co37lH7P+nsADA=",
|
"narHash": "sha256-dW9IOA31+j3VS19WAWAmkJW2YCzeVZGqd6HpIJfODtI=",
|
||||||
"owner": "huggingface",
|
"owner": "huggingface",
|
||||||
"repo": "text-generation-inference-nix",
|
"repo": "text-generation-inference-nix",
|
||||||
"rev": "93123736c97e9f7bfe825bfaf3d7de0fc9a21a1e",
|
"rev": "a84a90281a17b15762873845c947e5c78f5a8dd1",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
"owner": "huggingface",
|
"owner": "huggingface",
|
||||||
|
"ref": "marlin-kernels-0.3.0",
|
||||||
"repo": "text-generation-inference-nix",
|
"repo": "text-generation-inference-nix",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
|
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||||
};
|
};
|
||||||
nix-filter.url = "github:numtide/nix-filter";
|
nix-filter.url = "github:numtide/nix-filter";
|
||||||
tgi-nix.url = "github:huggingface/text-generation-inference-nix";
|
tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.0";
|
||||||
nixpkgs.follows = "tgi-nix/nixpkgs";
|
nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||||
flake-utils.url = "github:numtide/flake-utils";
|
flake-utils.url = "github:numtide/flake-utils";
|
||||||
rust-overlay = {
|
rust-overlay = {
|
||||||
|
@ -137,6 +137,11 @@
|
||||||
|
|
||||||
impure = callPackage ./nix/impure-shell.nix { inherit server; };
|
impure = callPackage ./nix/impure-shell.nix { inherit server; };
|
||||||
|
|
||||||
|
impureWithCuda = callPackage ./nix/impure-shell.nix {
|
||||||
|
inherit server;
|
||||||
|
withCuda = true;
|
||||||
|
};
|
||||||
|
|
||||||
impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {
|
impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {
|
||||||
server = server.override { flash-attn = python3.pkgs.flash-attn-v1; };
|
server = server.override { flash-attn = python3.pkgs.flash-attn-v1; };
|
||||||
};
|
};
|
||||||
|
|
|
@ -9,13 +9,16 @@ import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
import docker
|
import docker
|
||||||
import pytest
|
import pytest
|
||||||
|
import base64
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
|
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
|
||||||
from docker.errors import NotFound
|
from docker.errors import NotFound
|
||||||
from syrupy.extensions.json import JSONSnapshotExtension
|
from syrupy.extensions.json import JSONSnapshotExtension
|
||||||
|
|
||||||
from text_generation import AsyncClient
|
from text_generation import AsyncClient
|
||||||
from text_generation.types import (
|
from text_generation.types import (
|
||||||
BestOfSequence,
|
BestOfSequence,
|
||||||
|
@ -403,6 +406,7 @@ def launcher(event_loop):
|
||||||
print(" ".join(args), file=sys.stderr)
|
print(" ".join(args), file=sys.stderr)
|
||||||
|
|
||||||
env["LOG_LEVEL"] = "info,text_generation_router=debug"
|
env["LOG_LEVEL"] = "info,text_generation_router=debug"
|
||||||
|
env["PREFILL_CHUNKING"] = "1"
|
||||||
|
|
||||||
if not use_flash_attention:
|
if not use_flash_attention:
|
||||||
env["USE_FLASH_ATTENTION"] = "false"
|
env["USE_FLASH_ATTENTION"] = "false"
|
||||||
|
@ -501,6 +505,7 @@ def launcher(event_loop):
|
||||||
|
|
||||||
env = {
|
env = {
|
||||||
"LOG_LEVEL": "info,text_generation_router=debug",
|
"LOG_LEVEL": "info,text_generation_router=debug",
|
||||||
|
"PREFILL_CHUNKING": "1",
|
||||||
}
|
}
|
||||||
if not use_flash_attention:
|
if not use_flash_attention:
|
||||||
env["USE_FLASH_ATTENTION"] = "false"
|
env["USE_FLASH_ATTENTION"] = "false"
|
||||||
|
@ -642,3 +647,22 @@ def generate_multi():
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
return generate_load_inner
|
return generate_load_inner
|
||||||
|
|
||||||
|
|
||||||
|
# TODO fix the server parsser to count inline image tokens correctly
|
||||||
|
@pytest.fixture
|
||||||
|
def chicken():
|
||||||
|
path = Path(__file__).parent / "images" / "chicken_on_money.png"
|
||||||
|
|
||||||
|
with open(path, "rb") as image_file:
|
||||||
|
encoded_string = base64.b64encode(image_file.read())
|
||||||
|
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cow_beach():
|
||||||
|
path = Path(__file__).parent / "images" / "cow_beach.png"
|
||||||
|
|
||||||
|
with open(path, "rb") as image_file:
|
||||||
|
encoded_string = base64.b64encode(image_file.read())
|
||||||
|
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||||
|
|
|
@ -11,27 +11,27 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3923,
|
"id": 3923,
|
||||||
"logprob": -5.6328125,
|
"logprob": -6.1875,
|
||||||
"text": "What"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -1.2265625,
|
"logprob": -0.93359375,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5655,
|
"id": 5655,
|
||||||
"logprob": -9.1015625,
|
"logprob": -9.875,
|
||||||
"text": " deep"
|
"text": " deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -1.8085938,
|
"logprob": -1.1796875,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 30,
|
"id": 30,
|
||||||
"logprob": -1.0439453,
|
"logprob": -1.75,
|
||||||
"text": "?"
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -39,66 +39,66 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 18682,
|
"id": 18682,
|
||||||
"logprob": -2.1992188,
|
"logprob": -1.109375,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Deep"
|
"text": " Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.079956055,
|
"logprob": -0.005432129,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -0.2763672,
|
"logprob": -0.028808594,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 264,
|
"id": 264,
|
||||||
"logprob": -0.37548828,
|
"logprob": -0.013671875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 27084,
|
"id": 27084,
|
||||||
"logprob": -1.4628906,
|
"logprob": -0.69921875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " subset"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 315,
|
"id": 315,
|
||||||
"logprob": -0.02885437,
|
"logprob": -0.0005874634,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " of"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5780,
|
"id": 5780,
|
||||||
"logprob": -0.2565918,
|
"logprob": -0.026855469,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " machine"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.0063438416,
|
"logprob": -0.00020885468,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 430,
|
"id": 430,
|
||||||
"logprob": -1.3056641,
|
"logprob": -0.17773438,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " that"
|
"text": " that"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 18065,
|
||||||
"logprob": -1.6035156,
|
"logprob": -0.703125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " involves"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
"generated_text": " Deep learning is a subset of machine learning that involves"
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
"best_of_sequences": null,
|
"best_of_sequences": null,
|
||||||
"finish_reason": "eos_token",
|
"finish_reason": "length",
|
||||||
"generated_tokens": 3,
|
"generated_tokens": 10,
|
||||||
"prefill": [
|
"prefill": [
|
||||||
{
|
{
|
||||||
"id": 128000,
|
"id": 128000,
|
||||||
|
@ -11,22 +11,22 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -22.96875,
|
"logprob": -18.0,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5655,
|
"id": 5655,
|
||||||
"logprob": -10.71875,
|
"logprob": -11.75,
|
||||||
"text": " deep"
|
"text": " deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -2.6992188,
|
"logprob": -2.0625,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 30,
|
"id": 30,
|
||||||
"logprob": -4.8398438,
|
"logprob": -6.0,
|
||||||
"text": "?"
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -34,24 +34,66 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 720,
|
"id": 720,
|
||||||
"logprob": -0.4411621,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " \n"
|
"text": " \n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 220,
|
"id": 34564,
|
||||||
"logprob": -0.35864258,
|
"logprob": -0.11279297,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " "
|
"text": "Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 128001,
|
"id": 6975,
|
||||||
|
"logprob": -0.16015625,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 320,
|
||||||
|
"logprob": -0.25195312,
|
||||||
|
"special": false,
|
||||||
|
"text": " ("
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 16931,
|
||||||
|
"logprob": -1.703125,
|
||||||
|
"special": false,
|
||||||
|
"text": "DL"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": true,
|
"special": false,
|
||||||
"text": "<|end_of_text|>"
|
"text": ")"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -1.140625,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 264,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1207,
|
||||||
|
"logprob": -1.3125,
|
||||||
|
"special": false,
|
||||||
|
"text": " sub"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2630,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "field"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "What is deep learning? \n "
|
"generated_text": "What is deep learning? \nDeep learning (DL) is a subfield"
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,27 +12,27 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3923,
|
"id": 3923,
|
||||||
"logprob": -5.6328125,
|
"logprob": -6.1875,
|
||||||
"text": "What"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -1.2265625,
|
"logprob": -0.93359375,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5655,
|
"id": 5655,
|
||||||
"logprob": -9.1015625,
|
"logprob": -9.875,
|
||||||
"text": " deep"
|
"text": " deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -1.8085938,
|
"logprob": -1.1796875,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 30,
|
"id": 30,
|
||||||
"logprob": -1.0439453,
|
"logprob": -1.75,
|
||||||
"text": "?"
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -40,68 +40,68 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 18682,
|
"id": 18682,
|
||||||
"logprob": -2.1992188,
|
"logprob": -1.109375,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Deep"
|
"text": " Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.07897949,
|
"logprob": -0.0047912598,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -0.27734375,
|
"logprob": -0.025512695,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 264,
|
"id": 264,
|
||||||
"logprob": -0.37402344,
|
"logprob": -0.012145996,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 27084,
|
"id": 27084,
|
||||||
"logprob": -1.4511719,
|
"logprob": -0.72265625,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " subset"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 315,
|
"id": 315,
|
||||||
"logprob": -0.02909851,
|
"logprob": -0.0005760193,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " of"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5780,
|
"id": 5780,
|
||||||
"logprob": -0.25854492,
|
"logprob": -0.02722168,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " machine"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.0061798096,
|
"logprob": -0.00023651123,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 430,
|
"id": 430,
|
||||||
"logprob": -1.3046875,
|
"logprob": -0.17285156,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " that"
|
"text": " that"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 18065,
|
||||||
"logprob": -1.5537109,
|
"logprob": -0.703125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " involves"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
"generated_text": " Deep learning is a subset of machine learning that involves"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
|
@ -116,27 +116,27 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3923,
|
"id": 3923,
|
||||||
"logprob": -5.6328125,
|
"logprob": -6.21875,
|
||||||
"text": "What"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -1.2265625,
|
"logprob": -0.95703125,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5655,
|
"id": 5655,
|
||||||
"logprob": -9.1015625,
|
"logprob": -9.9375,
|
||||||
"text": " deep"
|
"text": " deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -1.8085938,
|
"logprob": -1.1328125,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 30,
|
"id": 30,
|
||||||
"logprob": -1.0439453,
|
"logprob": -1.75,
|
||||||
"text": "?"
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -144,68 +144,68 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 18682,
|
"id": 18682,
|
||||||
"logprob": -2.1992188,
|
"logprob": -1.1796875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Deep"
|
"text": " Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.07897949,
|
"logprob": -0.005432129,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -0.27734375,
|
"logprob": -0.02758789,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 264,
|
"id": 264,
|
||||||
"logprob": -0.37402344,
|
"logprob": -0.013366699,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 27084,
|
"id": 27084,
|
||||||
"logprob": -1.4511719,
|
"logprob": -0.6953125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " subset"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 315,
|
"id": 315,
|
||||||
"logprob": -0.02909851,
|
"logprob": -0.0004863739,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " of"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5780,
|
"id": 5780,
|
||||||
"logprob": -0.25854492,
|
"logprob": -0.02709961,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " machine"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.0061798096,
|
"logprob": -0.00022506714,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 430,
|
"id": 430,
|
||||||
"logprob": -1.3046875,
|
"logprob": -0.19726562,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " that"
|
"text": " that"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 18065,
|
||||||
"logprob": -1.5537109,
|
"logprob": -0.77734375,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " involves"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
"generated_text": " Deep learning is a subset of machine learning that involves"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
|
@ -220,27 +220,27 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3923,
|
"id": 3923,
|
||||||
"logprob": -5.6328125,
|
"logprob": -6.21875,
|
||||||
"text": "What"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -1.2265625,
|
"logprob": -0.95703125,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5655,
|
"id": 5655,
|
||||||
"logprob": -9.1015625,
|
"logprob": -9.9375,
|
||||||
"text": " deep"
|
"text": " deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -1.8085938,
|
"logprob": -1.1328125,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 30,
|
"id": 30,
|
||||||
"logprob": -1.0439453,
|
"logprob": -1.75,
|
||||||
"text": "?"
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -248,68 +248,68 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 18682,
|
"id": 18682,
|
||||||
"logprob": -2.1992188,
|
"logprob": -1.1796875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Deep"
|
"text": " Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.07897949,
|
"logprob": -0.005432129,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -0.27734375,
|
"logprob": -0.02758789,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 264,
|
"id": 264,
|
||||||
"logprob": -0.37402344,
|
"logprob": -0.013366699,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 27084,
|
"id": 27084,
|
||||||
"logprob": -1.4511719,
|
"logprob": -0.6953125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " subset"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 315,
|
"id": 315,
|
||||||
"logprob": -0.02909851,
|
"logprob": -0.0004863739,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " of"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5780,
|
"id": 5780,
|
||||||
"logprob": -0.25854492,
|
"logprob": -0.02709961,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " machine"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.0061798096,
|
"logprob": -0.00022506714,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 430,
|
"id": 430,
|
||||||
"logprob": -1.3046875,
|
"logprob": -0.19726562,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " that"
|
"text": " that"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 18065,
|
||||||
"logprob": -1.5537109,
|
"logprob": -0.77734375,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " involves"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
"generated_text": " Deep learning is a subset of machine learning that involves"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
|
@ -324,27 +324,27 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3923,
|
"id": 3923,
|
||||||
"logprob": -5.6328125,
|
"logprob": -6.21875,
|
||||||
"text": "What"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -1.2265625,
|
"logprob": -0.95703125,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5655,
|
"id": 5655,
|
||||||
"logprob": -9.1015625,
|
"logprob": -9.9375,
|
||||||
"text": " deep"
|
"text": " deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -1.8085938,
|
"logprob": -1.1328125,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 30,
|
"id": 30,
|
||||||
"logprob": -1.0439453,
|
"logprob": -1.75,
|
||||||
"text": "?"
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -352,67 +352,67 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 18682,
|
"id": 18682,
|
||||||
"logprob": -2.1992188,
|
"logprob": -1.1796875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Deep"
|
"text": " Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.07897949,
|
"logprob": -0.005432129,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 374,
|
||||||
"logprob": -0.27734375,
|
"logprob": -0.02758789,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 264,
|
"id": 264,
|
||||||
"logprob": -0.37402344,
|
"logprob": -0.013366699,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 27084,
|
"id": 27084,
|
||||||
"logprob": -1.4511719,
|
"logprob": -0.6953125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " subset"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 315,
|
"id": 315,
|
||||||
"logprob": -0.02909851,
|
"logprob": -0.0004863739,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " of"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5780,
|
"id": 5780,
|
||||||
"logprob": -0.25854492,
|
"logprob": -0.02709961,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " machine"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6975,
|
"id": 6975,
|
||||||
"logprob": -0.0061798096,
|
"logprob": -0.00022506714,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " learning"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 430,
|
"id": 430,
|
||||||
"logprob": -1.3046875,
|
"logprob": -0.19726562,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " that"
|
"text": " that"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 374,
|
"id": 18065,
|
||||||
"logprob": -1.5537109,
|
"logprob": -0.77734375,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " involves"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": " Deep learning is a subset of machine learning that is"
|
"generated_text": " Deep learning is a subset of machine learning that involves"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
|
@ -10,80 +10,95 @@
|
||||||
"text": "<s>"
|
"text": "<s>"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 1824,
|
||||||
"logprob": -11.0078125,
|
"logprob": -9.2890625,
|
||||||
"text": "Test"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -13.59375,
|
"logprob": -1.1503906,
|
||||||
"text": "request"
|
"text": "is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3534,
|
||||||
|
"logprob": -9.5859375,
|
||||||
|
"text": "deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5168,
|
||||||
|
"logprob": -1.3945312,
|
||||||
|
"text": "learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -0.4555664,
|
||||||
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"seed": null,
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -1.7089844,
|
"logprob": -0.6953125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.68847656,
|
"logprob": -0.4777832,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 28771,
|
"id": 23229,
|
||||||
"logprob": -1.9394531,
|
"logprob": -0.13256836,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "#"
|
"text": "Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5168,
|
||||||
"logprob": -2.8808594,
|
"logprob": -0.023849487,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -0.37280273,
|
"logprob": -0.13977051,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 264,
|
||||||
"logprob": -0.26098633,
|
"logprob": -0.14489746,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 19804,
|
||||||
"logprob": -0.0017137527,
|
"logprob": -0.63183594,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1064,
|
"id": 302,
|
||||||
"logprob": -2.2695312,
|
"logprob": -0.010314941,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "##"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5599,
|
||||||
"logprob": -1.9238281,
|
"logprob": -0.0635376,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 5168,
|
||||||
"logprob": -0.48828125,
|
"logprob": -0.0028572083,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " learning"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n\n# Test request\n\n## Test request"
|
"generated_text": "\n\nDeep learning is a subset of machine learning"
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,42 +10,28 @@
|
||||||
"text": "<s>"
|
"text": "<s>"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 349,
|
||||||
"logprob": -11.0078125,
|
"logprob": -12.0546875,
|
||||||
"text": "Test"
|
"text": "is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 3534,
|
||||||
"logprob": -13.59375,
|
"logprob": -10.53125,
|
||||||
"text": "request"
|
"text": "deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5168,
|
||||||
|
"logprob": -2.71875,
|
||||||
|
"text": "learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -5.0078125,
|
||||||
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"seed": 0,
|
"seed": 0,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
|
||||||
"id": 13,
|
|
||||||
"logprob": -0.34838867,
|
|
||||||
"special": false,
|
|
||||||
"text": "\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 13940,
|
|
||||||
"logprob": -0.38916016,
|
|
||||||
"special": false,
|
|
||||||
"text": "``"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 28832,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": "`"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3371,
|
|
||||||
"logprob": -1.2529297,
|
|
||||||
"special": false,
|
|
||||||
"text": "json"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
|
@ -53,37 +39,61 @@
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 28751,
|
"id": 23229,
|
||||||
"logprob": 0.0,
|
"logprob": -0.18237305,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "{"
|
"text": "Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 17504,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " Learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2287,
|
"id": 349,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " "
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 345,
|
"id": 264,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " \""
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3134,
|
"id": 19804,
|
||||||
"logprob": -0.640625,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "request"
|
"text": " subset"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 302,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13253,
|
||||||
|
"logprob": -0.6040039,
|
||||||
|
"special": false,
|
||||||
|
"text": " Machine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 17504,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " Learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28725,
|
||||||
|
"logprob": -0.11621094,
|
||||||
|
"special": false,
|
||||||
|
"text": ","
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "Test request\n```json\n{\n \"request"
|
"generated_text": "What is deep learning?\nDeep Learning is a subset of Machine Learning,"
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,82 +11,97 @@
|
||||||
"text": "<s>"
|
"text": "<s>"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 1824,
|
||||||
"logprob": -11.0078125,
|
"logprob": -9.2890625,
|
||||||
"text": "Test"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -13.59375,
|
"logprob": -1.1503906,
|
||||||
"text": "request"
|
"text": "is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3534,
|
||||||
|
"logprob": -9.5859375,
|
||||||
|
"text": "deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5168,
|
||||||
|
"logprob": -1.3945312,
|
||||||
|
"text": "learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -0.4555664,
|
||||||
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"seed": null,
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -1.7089844,
|
"logprob": -0.6953125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.68847656,
|
"logprob": -0.4777832,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 28771,
|
"id": 23229,
|
||||||
"logprob": -1.9394531,
|
"logprob": -0.13232422,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "#"
|
"text": "Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5168,
|
||||||
"logprob": -2.8828125,
|
"logprob": -0.023834229,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -0.37329102,
|
"logprob": -0.13977051,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 264,
|
||||||
"logprob": -0.2602539,
|
"logprob": -0.14416504,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 19804,
|
||||||
"logprob": -0.0017185211,
|
"logprob": -0.63183594,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1064,
|
"id": 302,
|
||||||
"logprob": -2.2753906,
|
"logprob": -0.010223389,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "##"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5599,
|
||||||
"logprob": -1.9316406,
|
"logprob": -0.064208984,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 5168,
|
||||||
"logprob": -0.48217773,
|
"logprob": -0.0028266907,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " learning"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n\n# Test request\n\n## Test request"
|
"generated_text": "\n\nDeep learning is a subset of machine learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
|
@ -100,82 +115,97 @@
|
||||||
"text": "<s>"
|
"text": "<s>"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 1824,
|
||||||
"logprob": -11.0078125,
|
"logprob": -9.2890625,
|
||||||
"text": "Test"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -13.59375,
|
"logprob": -1.1425781,
|
||||||
"text": "request"
|
"text": "is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3534,
|
||||||
|
"logprob": -9.59375,
|
||||||
|
"text": "deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5168,
|
||||||
|
"logprob": -1.390625,
|
||||||
|
"text": "learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -0.45532227,
|
||||||
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"seed": null,
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -1.7089844,
|
"logprob": -0.6953125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.68847656,
|
"logprob": -0.48339844,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 28771,
|
"id": 23229,
|
||||||
"logprob": -1.9394531,
|
"logprob": -0.13256836,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "#"
|
"text": "Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5168,
|
||||||
"logprob": -2.8828125,
|
"logprob": -0.02420044,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -0.37329102,
|
"logprob": -0.13977051,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 264,
|
||||||
"logprob": -0.2602539,
|
"logprob": -0.14501953,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 19804,
|
||||||
"logprob": -0.0017185211,
|
"logprob": -0.63134766,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1064,
|
"id": 302,
|
||||||
"logprob": -2.2753906,
|
"logprob": -0.010223389,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "##"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5599,
|
||||||
"logprob": -1.9316406,
|
"logprob": -0.06427002,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 5168,
|
||||||
"logprob": -0.48217773,
|
"logprob": -0.002817154,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " learning"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n\n# Test request\n\n## Test request"
|
"generated_text": "\n\nDeep learning is a subset of machine learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
|
@ -189,82 +219,97 @@
|
||||||
"text": "<s>"
|
"text": "<s>"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 1824,
|
||||||
"logprob": -11.0078125,
|
"logprob": -9.2890625,
|
||||||
"text": "Test"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -13.59375,
|
"logprob": -1.1425781,
|
||||||
"text": "request"
|
"text": "is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3534,
|
||||||
|
"logprob": -9.59375,
|
||||||
|
"text": "deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5168,
|
||||||
|
"logprob": -1.390625,
|
||||||
|
"text": "learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -0.45532227,
|
||||||
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"seed": null,
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -1.7089844,
|
"logprob": -0.6953125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.68847656,
|
"logprob": -0.48339844,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 28771,
|
"id": 23229,
|
||||||
"logprob": -1.9394531,
|
"logprob": -0.13256836,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "#"
|
"text": "Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5168,
|
||||||
"logprob": -2.8828125,
|
"logprob": -0.02420044,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -0.37329102,
|
"logprob": -0.13977051,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 264,
|
||||||
"logprob": -0.2602539,
|
"logprob": -0.14501953,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 19804,
|
||||||
"logprob": -0.0017185211,
|
"logprob": -0.63134766,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1064,
|
"id": 302,
|
||||||
"logprob": -2.2753906,
|
"logprob": -0.010223389,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "##"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5599,
|
||||||
"logprob": -1.9316406,
|
"logprob": -0.06427002,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 5168,
|
||||||
"logprob": -0.48217773,
|
"logprob": -0.002817154,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " learning"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n\n# Test request\n\n## Test request"
|
"generated_text": "\n\nDeep learning is a subset of machine learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
|
@ -278,81 +323,96 @@
|
||||||
"text": "<s>"
|
"text": "<s>"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 1824,
|
||||||
"logprob": -11.0078125,
|
"logprob": -9.2890625,
|
||||||
"text": "Test"
|
"text": "What"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -13.59375,
|
"logprob": -1.1425781,
|
||||||
"text": "request"
|
"text": "is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3534,
|
||||||
|
"logprob": -9.59375,
|
||||||
|
"text": "deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5168,
|
||||||
|
"logprob": -1.390625,
|
||||||
|
"text": "learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -0.45532227,
|
||||||
|
"text": "?"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"seed": null,
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -1.7089844,
|
"logprob": -0.6953125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.68847656,
|
"logprob": -0.48339844,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 28771,
|
"id": 23229,
|
||||||
"logprob": -1.9394531,
|
"logprob": -0.13256836,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "#"
|
"text": "Deep"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5168,
|
||||||
"logprob": -2.8828125,
|
"logprob": -0.02420044,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " learning"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 349,
|
||||||
"logprob": -0.37329102,
|
"logprob": -0.13977051,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 264,
|
||||||
"logprob": -0.2602539,
|
"logprob": -0.14501953,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 19804,
|
||||||
"logprob": -0.0017185211,
|
"logprob": -0.63134766,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " subset"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1064,
|
"id": 302,
|
||||||
"logprob": -2.2753906,
|
"logprob": -0.010223389,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "##"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3735,
|
"id": 5599,
|
||||||
"logprob": -1.9316406,
|
"logprob": -0.06427002,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " Test"
|
"text": " machine"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2159,
|
"id": 5168,
|
||||||
"logprob": -0.48217773,
|
"logprob": -0.002817154,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " learning"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n\n# Test request\n\n## Test request"
|
"generated_text": "\n\nDeep learning is a subset of machine learning"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
|
@ -11,32 +11,32 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 338,
|
"id": 338,
|
||||||
"logprob": -0.7133789,
|
"logprob": -0.6201172,
|
||||||
"text": "is"
|
"text": "is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 16030,
|
"id": 16030,
|
||||||
"logprob": -13.9296875,
|
"logprob": -13.6484375,
|
||||||
"text": "gradient"
|
"text": "gradient"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26815,
|
"id": 26815,
|
||||||
"logprob": -0.048919678,
|
"logprob": -0.003894806,
|
||||||
"text": "descent"
|
"text": "descent"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29973,
|
"id": 29973,
|
||||||
"logprob": -3.0078125,
|
"logprob": -2.6386719,
|
||||||
"text": "?"
|
"text": "?"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -2.8105469,
|
"logprob": -6.46875,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.84521484,
|
"logprob": -6.6875,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -44,66 +44,66 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 25584,
|
"id": 25584,
|
||||||
"logprob": -0.017028809,
|
"logprob": -0.008979797,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "Grad"
|
"text": "Grad"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 993,
|
"id": 993,
|
||||||
"logprob": -0.0027313232,
|
"logprob": -8.34465e-07,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "ient"
|
"text": "ient"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26815,
|
"id": 26815,
|
||||||
"logprob": -0.023254395,
|
"logprob": -0.0009407997,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " descent"
|
"text": " descent"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 338,
|
"id": 338,
|
||||||
"logprob": -2.0623207e-05,
|
"logprob": -0.0003838539,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 263,
|
"id": 385,
|
||||||
"logprob": -0.5361328,
|
"logprob": -0.24499512,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": " an"
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 937,
|
|
||||||
"logprob": -0.17578125,
|
|
||||||
"special": false,
|
|
||||||
"text": " first"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29899,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": "-"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 2098,
|
|
||||||
"logprob": -0.00011539459,
|
|
||||||
"special": false,
|
|
||||||
"text": "order"
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13883,
|
"id": 13883,
|
||||||
"logprob": -0.47436523,
|
"logprob": -0.010406494,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " optimization"
|
"text": " optimization"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5687,
|
"id": 5687,
|
||||||
"logprob": -0.00027680397,
|
"logprob": -0.00024354458,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " algorithm"
|
"text": " algorithm"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 15574,
|
||||||
|
"logprob": -0.6582031,
|
||||||
|
"special": false,
|
||||||
|
"text": " commonly"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1304,
|
||||||
|
"logprob": -0.00092840195,
|
||||||
|
"special": false,
|
||||||
|
"text": " used"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 297,
|
||||||
|
"logprob": -0.19470215,
|
||||||
|
"special": false,
|
||||||
|
"text": " in"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "Gradient descent is a first-order optimization algorithm"
|
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,95 +5,95 @@
|
||||||
"generated_tokens": 10,
|
"generated_tokens": 10,
|
||||||
"prefill": [
|
"prefill": [
|
||||||
{
|
{
|
||||||
"id": 16030,
|
"id": 338,
|
||||||
"logprob": null,
|
"logprob": null,
|
||||||
|
"text": "is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 16030,
|
||||||
|
"logprob": -13.328125,
|
||||||
"text": "gradient"
|
"text": "gradient"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26815,
|
"id": 26815,
|
||||||
"logprob": -6.4960938,
|
"logprob": -0.24023438,
|
||||||
"text": "descent"
|
"text": "descent"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29973,
|
"id": 29973,
|
||||||
"logprob": -5.1484375,
|
"logprob": -3.1386719,
|
||||||
"text": "?"
|
"text": "?"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -4.0351562,
|
"logprob": -3.0878906,
|
||||||
"text": "\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 13,
|
|
||||||
"logprob": -5.2265625,
|
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"seed": 0,
|
"seed": 0,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 10994,
|
"id": 25584,
|
||||||
"logprob": -1.1542969,
|
|
||||||
"special": false,
|
|
||||||
"text": "Hello"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29991,
|
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "!"
|
"text": "Grad"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 739,
|
"id": 993,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " It"
|
"text": "ient"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2444,
|
"id": 2726,
|
||||||
"logprob": -0.42260742,
|
|
||||||
"special": false,
|
|
||||||
"text": " seems"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 366,
|
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " you"
|
"text": " Des"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29915,
|
"id": 1760,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "'"
|
"text": "cent"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 276,
|
"id": 313,
|
||||||
"logprob": -0.9838867,
|
"logprob": -0.12322998,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "re"
|
"text": " ("
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3211,
|
"id": 29954,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " address"
|
"text": "G"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 292,
|
"id": 29928,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "ing"
|
"text": "D"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 263,
|
"id": 29897,
|
||||||
"logprob": -0.15124512,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": ")"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 338,
|
||||||
|
"logprob": -0.6040039,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 385,
|
||||||
|
"logprob": -0.1796875,
|
||||||
|
"special": false,
|
||||||
|
"text": " an"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "What is gradient descent?\n\nHello! It seems you're addressing a"
|
"generated_text": "What is gradient descent?\nGradient Descent (GD) is an"
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,32 +12,32 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 338,
|
"id": 338,
|
||||||
"logprob": -0.7133789,
|
"logprob": -0.6201172,
|
||||||
"text": "is"
|
"text": "is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 16030,
|
"id": 16030,
|
||||||
"logprob": -13.9296875,
|
"logprob": -13.6484375,
|
||||||
"text": "gradient"
|
"text": "gradient"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26815,
|
"id": 26815,
|
||||||
"logprob": -0.048919678,
|
"logprob": -0.003894806,
|
||||||
"text": "descent"
|
"text": "descent"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29973,
|
"id": 29973,
|
||||||
"logprob": -3.0078125,
|
"logprob": -2.6386719,
|
||||||
"text": "?"
|
"text": "?"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -2.8105469,
|
"logprob": -6.46875,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.84521484,
|
"logprob": -6.6875,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -45,68 +45,68 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 25584,
|
"id": 25584,
|
||||||
"logprob": -0.017028809,
|
"logprob": -0.008979797,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "Grad"
|
"text": "Grad"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 993,
|
"id": 993,
|
||||||
"logprob": -0.0028476715,
|
"logprob": -8.34465e-07,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "ient"
|
"text": "ient"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26815,
|
"id": 26815,
|
||||||
"logprob": -0.023971558,
|
"logprob": -0.00097084045,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " descent"
|
"text": " descent"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 338,
|
"id": 338,
|
||||||
"logprob": -2.0384789e-05,
|
"logprob": -0.0003838539,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 263,
|
"id": 385,
|
||||||
"logprob": -0.5229492,
|
"logprob": -0.23840332,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": " an"
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 937,
|
|
||||||
"logprob": -0.17602539,
|
|
||||||
"special": false,
|
|
||||||
"text": " first"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29899,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": "-"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 2098,
|
|
||||||
"logprob": -0.000116467476,
|
|
||||||
"special": false,
|
|
||||||
"text": "order"
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13883,
|
"id": 13883,
|
||||||
"logprob": -0.47436523,
|
"logprob": -0.010406494,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " optimization"
|
"text": " optimization"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5687,
|
"id": 5687,
|
||||||
"logprob": -0.00027871132,
|
"logprob": -0.0002501011,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " algorithm"
|
"text": " algorithm"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 15574,
|
||||||
|
"logprob": -0.6582031,
|
||||||
|
"special": false,
|
||||||
|
"text": " commonly"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1304,
|
||||||
|
"logprob": -0.00092840195,
|
||||||
|
"special": false,
|
||||||
|
"text": " used"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 297,
|
||||||
|
"logprob": -0.18933105,
|
||||||
|
"special": false,
|
||||||
|
"text": " in"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "Gradient descent is a first-order optimization algorithm"
|
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
|
@ -121,32 +121,32 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 338,
|
"id": 338,
|
||||||
"logprob": -0.7128906,
|
"logprob": -0.6113281,
|
||||||
"text": "is"
|
"text": "is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 16030,
|
"id": 16030,
|
||||||
"logprob": -13.9375,
|
"logprob": -13.6640625,
|
||||||
"text": "gradient"
|
"text": "gradient"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26815,
|
"id": 26815,
|
||||||
"logprob": -0.05053711,
|
"logprob": -0.003929138,
|
||||||
"text": "descent"
|
"text": "descent"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29973,
|
"id": 29973,
|
||||||
"logprob": -3.0058594,
|
"logprob": -2.625,
|
||||||
"text": "?"
|
"text": "?"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -2.8242188,
|
"logprob": -6.484375,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.84521484,
|
"logprob": -6.6875,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -154,68 +154,68 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 25584,
|
"id": 25584,
|
||||||
"logprob": -0.018859863,
|
"logprob": -0.009017944,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "Grad"
|
"text": "Grad"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 993,
|
"id": 993,
|
||||||
"logprob": -0.002822876,
|
"logprob": -9.536743e-07,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "ient"
|
"text": "ient"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26815,
|
"id": 26815,
|
||||||
"logprob": -0.023254395,
|
"logprob": -0.00097084045,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " descent"
|
"text": " descent"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 338,
|
"id": 338,
|
||||||
"logprob": -2.0384789e-05,
|
"logprob": -0.0003838539,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 263,
|
"id": 385,
|
||||||
"logprob": -0.5229492,
|
"logprob": -0.24499512,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": " an"
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 937,
|
|
||||||
"logprob": -0.17126465,
|
|
||||||
"special": false,
|
|
||||||
"text": " first"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29899,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": "-"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 2098,
|
|
||||||
"logprob": -0.0001155138,
|
|
||||||
"special": false,
|
|
||||||
"text": "order"
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13883,
|
"id": 13883,
|
||||||
"logprob": -0.47436523,
|
"logprob": -0.010406494,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " optimization"
|
"text": " optimization"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5687,
|
"id": 5687,
|
||||||
"logprob": -0.00027036667,
|
"logprob": -0.0002501011,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " algorithm"
|
"text": " algorithm"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 15574,
|
||||||
|
"logprob": -0.6435547,
|
||||||
|
"special": false,
|
||||||
|
"text": " commonly"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1304,
|
||||||
|
"logprob": -0.0009279251,
|
||||||
|
"special": false,
|
||||||
|
"text": " used"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 297,
|
||||||
|
"logprob": -0.18933105,
|
||||||
|
"special": false,
|
||||||
|
"text": " in"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "Gradient descent is a first-order optimization algorithm"
|
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
|
@ -230,32 +230,32 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 338,
|
"id": 338,
|
||||||
"logprob": -0.71484375,
|
"logprob": -0.609375,
|
||||||
"text": "is"
|
"text": "is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 16030,
|
"id": 16030,
|
||||||
"logprob": -13.9375,
|
"logprob": -13.671875,
|
||||||
"text": "gradient"
|
"text": "gradient"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26815,
|
"id": 26815,
|
||||||
"logprob": -0.049346924,
|
"logprob": -0.0040016174,
|
||||||
"text": "descent"
|
"text": "descent"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29973,
|
"id": 29973,
|
||||||
"logprob": -3.0078125,
|
"logprob": -2.6230469,
|
||||||
"text": "?"
|
"text": "?"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -2.8242188,
|
"logprob": -6.453125,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.86328125,
|
"logprob": -6.6875,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -263,68 +263,68 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 25584,
|
"id": 25584,
|
||||||
"logprob": -0.017196655,
|
"logprob": -0.008956909,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "Grad"
|
"text": "Grad"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 993,
|
"id": 993,
|
||||||
"logprob": -0.0028438568,
|
"logprob": -8.34465e-07,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "ient"
|
"text": "ient"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26815,
|
"id": 26815,
|
||||||
"logprob": -0.023254395,
|
"logprob": -0.0009407997,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " descent"
|
"text": " descent"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 338,
|
"id": 338,
|
||||||
"logprob": -2.026558e-05,
|
"logprob": -0.0003721714,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 263,
|
"id": 385,
|
||||||
"logprob": -0.5229492,
|
"logprob": -0.24499512,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": " an"
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 937,
|
|
||||||
"logprob": -0.17602539,
|
|
||||||
"special": false,
|
|
||||||
"text": " first"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29899,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": "-"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 2098,
|
|
||||||
"logprob": -0.00011622906,
|
|
||||||
"special": false,
|
|
||||||
"text": "order"
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13883,
|
"id": 13883,
|
||||||
"logprob": -0.48608398,
|
"logprob": -0.010406494,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " optimization"
|
"text": " optimization"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5687,
|
"id": 5687,
|
||||||
"logprob": -0.00027894974,
|
"logprob": -0.0002501011,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " algorithm"
|
"text": " algorithm"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 15574,
|
||||||
|
"logprob": -0.6435547,
|
||||||
|
"special": false,
|
||||||
|
"text": " commonly"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1304,
|
||||||
|
"logprob": -0.00092601776,
|
||||||
|
"special": false,
|
||||||
|
"text": " used"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 297,
|
||||||
|
"logprob": -0.19177246,
|
||||||
|
"special": false,
|
||||||
|
"text": " in"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "Gradient descent is a first-order optimization algorithm"
|
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
|
@ -339,32 +339,32 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 338,
|
"id": 338,
|
||||||
"logprob": -0.7192383,
|
"logprob": -0.609375,
|
||||||
"text": "is"
|
"text": "is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 16030,
|
"id": 16030,
|
||||||
"logprob": -13.9375,
|
"logprob": -13.6640625,
|
||||||
"text": "gradient"
|
"text": "gradient"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26815,
|
"id": 26815,
|
||||||
"logprob": -0.050445557,
|
"logprob": -0.0038967133,
|
||||||
"text": "descent"
|
"text": "descent"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29973,
|
"id": 29973,
|
||||||
"logprob": -3.0078125,
|
"logprob": -2.6347656,
|
||||||
"text": "?"
|
"text": "?"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -2.8242188,
|
"logprob": -6.453125,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.8276367,
|
"logprob": -6.6875,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -372,67 +372,67 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 25584,
|
"id": 25584,
|
||||||
"logprob": -0.01727295,
|
"logprob": -0.008979797,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "Grad"
|
"text": "Grad"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 993,
|
"id": 993,
|
||||||
"logprob": -0.0027542114,
|
"logprob": -9.536743e-07,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "ient"
|
"text": "ient"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26815,
|
"id": 26815,
|
||||||
"logprob": -0.023254395,
|
"logprob": -0.0009407997,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " descent"
|
"text": " descent"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 338,
|
"id": 338,
|
||||||
"logprob": -2.0384789e-05,
|
"logprob": -0.00038409233,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 263,
|
"id": 385,
|
||||||
"logprob": -0.5229492,
|
"logprob": -0.24499512,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": " an"
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 937,
|
|
||||||
"logprob": -0.17126465,
|
|
||||||
"special": false,
|
|
||||||
"text": " first"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29899,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": "-"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 2098,
|
|
||||||
"logprob": -0.00011301041,
|
|
||||||
"special": false,
|
|
||||||
"text": "order"
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 13883,
|
"id": 13883,
|
||||||
"logprob": -0.48608398,
|
"logprob": -0.010414124,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " optimization"
|
"text": " optimization"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 5687,
|
"id": 5687,
|
||||||
"logprob": -0.00027894974,
|
"logprob": -0.00024354458,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " algorithm"
|
"text": " algorithm"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 15574,
|
||||||
|
"logprob": -0.6435547,
|
||||||
|
"special": false,
|
||||||
|
"text": " commonly"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1304,
|
||||||
|
"logprob": -0.0009279251,
|
||||||
|
"special": false,
|
||||||
|
"text": " used"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 297,
|
||||||
|
"logprob": -0.19470215,
|
||||||
|
"special": false,
|
||||||
|
"text": " in"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "Gradient descent is a first-order optimization algorithm"
|
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
|
@ -11,57 +11,57 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3226,
|
"id": 3226,
|
||||||
"logprob": -8.9453125,
|
"logprob": -9.0234375,
|
||||||
"text": " ge"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 21017,
|
"id": 21017,
|
||||||
"logprob": -8.8515625,
|
"logprob": -9.0859375,
|
||||||
"text": "ometric"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 81,
|
"id": 81,
|
||||||
"logprob": -0.21875,
|
"logprob": -0.25585938,
|
||||||
"text": "_"
|
"text": "_"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6009,
|
"id": 6009,
|
||||||
"logprob": -1.2773438,
|
"logprob": -2.1972656,
|
||||||
"text": "mean"
|
"text": "mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 26,
|
||||||
"logprob": -0.25195312,
|
"logprob": -0.2998047,
|
||||||
"text": "("
|
"text": "("
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 62,
|
||||||
"logprob": -4.8203125,
|
"logprob": -5.6445312,
|
||||||
"text": "L"
|
"text": "L"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 44,
|
"id": 44,
|
||||||
"logprob": -3.7734375,
|
"logprob": -3.0839844,
|
||||||
"text": ":"
|
"text": ":"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1682,
|
"id": 1682,
|
||||||
"logprob": -0.8310547,
|
"logprob": -0.6748047,
|
||||||
"text": " List"
|
"text": " List"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 77,
|
"id": 77,
|
||||||
"logprob": -0.22766113,
|
"logprob": -0.3864746,
|
||||||
"text": "["
|
"text": "["
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1808,
|
"id": 1808,
|
||||||
"logprob": -0.46240234,
|
"logprob": -0.9355469,
|
||||||
"text": "float"
|
"text": "float"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 10794,
|
"id": 10794,
|
||||||
"logprob": -3.0234375,
|
"logprob": -2.5371094,
|
||||||
"text": "]):"
|
"text": "]):"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -69,7 +69,7 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 284,
|
"id": 284,
|
||||||
"logprob": -0.04626465,
|
"logprob": -1.1679688,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n "
|
"text": "\n "
|
||||||
},
|
},
|
||||||
|
|
|
@ -11,57 +11,57 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3226,
|
"id": 3226,
|
||||||
"logprob": -8.9453125,
|
"logprob": -9.015625,
|
||||||
"text": " ge"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 21017,
|
"id": 21017,
|
||||||
"logprob": -8.859375,
|
"logprob": -9.0859375,
|
||||||
"text": "ometric"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 81,
|
"id": 81,
|
||||||
"logprob": -0.21984863,
|
"logprob": -0.25585938,
|
||||||
"text": "_"
|
"text": "_"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6009,
|
"id": 6009,
|
||||||
"logprob": -1.2861328,
|
"logprob": -2.2304688,
|
||||||
"text": "mean"
|
"text": "mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 26,
|
||||||
"logprob": -0.25219727,
|
"logprob": -0.29760742,
|
||||||
"text": "("
|
"text": "("
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 62,
|
||||||
"logprob": -4.8007812,
|
"logprob": -5.6796875,
|
||||||
"text": "L"
|
"text": "L"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 44,
|
"id": 44,
|
||||||
"logprob": -3.7949219,
|
"logprob": -3.0742188,
|
||||||
"text": ":"
|
"text": ":"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1682,
|
"id": 1682,
|
||||||
"logprob": -0.8046875,
|
"logprob": -0.67626953,
|
||||||
"text": " List"
|
"text": " List"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 77,
|
"id": 77,
|
||||||
"logprob": -0.22424316,
|
"logprob": -0.38842773,
|
||||||
"text": "["
|
"text": "["
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1808,
|
"id": 1808,
|
||||||
"logprob": -0.46191406,
|
"logprob": -0.9165039,
|
||||||
"text": "float"
|
"text": "float"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 10794,
|
"id": 10794,
|
||||||
"logprob": -3.0253906,
|
"logprob": -2.5527344,
|
||||||
"text": "]):"
|
"text": "]):"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -69,7 +69,7 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 284,
|
"id": 284,
|
||||||
"logprob": 0.0,
|
"logprob": -0.048583984,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n "
|
"text": "\n "
|
||||||
},
|
},
|
||||||
|
|
|
@ -26,7 +26,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 259,
|
"id": 259,
|
||||||
"logprob": -0.46948242,
|
"logprob": -0.47070312,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " "
|
"text": " "
|
||||||
},
|
},
|
||||||
|
@ -38,7 +38,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 35622,
|
"id": 35622,
|
||||||
"logprob": -0.79589844,
|
"logprob": -0.796875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " cloud"
|
"text": " cloud"
|
||||||
},
|
},
|
||||||
|
@ -75,5 +75,5 @@
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "Why is the sky blue?blue sky, clouds and clouds"
|
"generated_text": "Why is the sky blue?blue sky , clouds and clouds"
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,9 @@ import pytest
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def flash_llama_fp8_kv_cache_handle(launcher):
|
def flash_llama_fp8_kv_cache_handle(launcher):
|
||||||
with launcher(
|
with launcher(
|
||||||
"meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2"
|
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
|
||||||
|
num_shard=2,
|
||||||
|
kv_cache_dtype="fp8_e4m3fn",
|
||||||
) as handle:
|
) as handle:
|
||||||
yield handle
|
yield handle
|
||||||
|
|
||||||
|
@ -25,7 +27,7 @@ async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snaps
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
response.generated_text
|
response.generated_text
|
||||||
== " Deep learning is a subset of machine learning that is"
|
== " Deep learning is a subset of machine learning that involves"
|
||||||
)
|
)
|
||||||
assert response.details.generated_tokens == 10
|
assert response.details.generated_tokens == 10
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
@ -69,7 +71,7 @@ async def test_flash_llama_fp8_kv_cache_load(
|
||||||
assert len(responses) == 4
|
assert len(responses) == 4
|
||||||
assert (
|
assert (
|
||||||
responses[0].generated_text
|
responses[0].generated_text
|
||||||
== " Deep learning is a subset of machine learning that is"
|
== " Deep learning is a subset of machine learning that involves"
|
||||||
)
|
)
|
||||||
assert all(
|
assert all(
|
||||||
[r.generated_text == responses[0].generated_text for r in responses]
|
[r.generated_text == responses[0].generated_text for r in responses]
|
||||||
|
|
|
@ -3,7 +3,11 @@ import pytest
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def flash_mixtral_gptq_handle(launcher):
|
def flash_mixtral_gptq_handle(launcher):
|
||||||
with launcher("TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ", num_shard=2) as handle:
|
with launcher(
|
||||||
|
"TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ",
|
||||||
|
revision="gptq-4bit-128g-actorder_True",
|
||||||
|
num_shard=2,
|
||||||
|
) as handle:
|
||||||
yield handle
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,7 +20,12 @@ async def flash_mixtral_gptq(flash_mixtral_gptq_handle):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_flash_mixtral_gptq(flash_mixtral_gptq, response_snapshot):
|
async def test_flash_mixtral_gptq(flash_mixtral_gptq, response_snapshot):
|
||||||
response = await flash_mixtral_gptq.generate(
|
response = await flash_mixtral_gptq.generate(
|
||||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
"What is deep learning?", max_new_tokens=10, decoder_input_details=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert (
|
||||||
|
response.generated_text == "\n\nDeep learning is a subset of machine learning"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
@ -25,7 +34,7 @@ async def test_flash_mixtral_gptq(flash_mixtral_gptq, response_snapshot):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_flash_mixtral_gptq_all_params(flash_mixtral_gptq, response_snapshot):
|
async def test_flash_mixtral_gptq_all_params(flash_mixtral_gptq, response_snapshot):
|
||||||
response = await flash_mixtral_gptq.generate(
|
response = await flash_mixtral_gptq.generate(
|
||||||
"Test request",
|
"What is deep learning?",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
repetition_penalty=1.2,
|
repetition_penalty=1.2,
|
||||||
return_full_text=True,
|
return_full_text=True,
|
||||||
|
@ -41,6 +50,10 @@ async def test_flash_mixtral_gptq_all_params(flash_mixtral_gptq, response_snapsh
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
assert response.details.generated_tokens == 10
|
||||||
|
assert (
|
||||||
|
response.generated_text
|
||||||
|
== "What is deep learning?\nDeep Learning is a subset of Machine Learning,"
|
||||||
|
)
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@ -49,10 +62,14 @@ async def test_flash_mixtral_gptq_load(
|
||||||
flash_mixtral_gptq, generate_load, response_snapshot
|
flash_mixtral_gptq, generate_load, response_snapshot
|
||||||
):
|
):
|
||||||
responses = await generate_load(
|
responses = await generate_load(
|
||||||
flash_mixtral_gptq, "Test request", max_new_tokens=10, n=4
|
flash_mixtral_gptq, "What is deep learning?", max_new_tokens=10, n=4
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(responses) == 4
|
assert len(responses) == 4
|
||||||
|
assert (
|
||||||
|
responses[0].generated_text
|
||||||
|
== "\n\nDeep learning is a subset of machine learning"
|
||||||
|
)
|
||||||
assert all(
|
assert all(
|
||||||
[r.generated_text == responses[0].generated_text for r in responses]
|
[r.generated_text == responses[0].generated_text for r in responses]
|
||||||
), f"{[r.generated_text for r in responses]}"
|
), f"{[r.generated_text for r in responses]}"
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import pytest
|
import pytest
|
||||||
import base64
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@ -20,24 +19,11 @@ async def flash_pali_gemma(flash_pali_gemma_handle):
|
||||||
return flash_pali_gemma_handle.client
|
return flash_pali_gemma_handle.client
|
||||||
|
|
||||||
|
|
||||||
def get_chicken():
|
|
||||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
|
||||||
|
|
||||||
|
|
||||||
def get_cow_beach():
|
|
||||||
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
|
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot, cow_beach):
|
||||||
cow = get_cow_beach()
|
inputs = f"![]({cow_beach})Where is the cow standing?\n"
|
||||||
inputs = f"![]({cow})Where is the cow standing?\n"
|
|
||||||
response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)
|
response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)
|
||||||
|
|
||||||
assert response.generated_text == "beach"
|
assert response.generated_text == "beach"
|
||||||
|
@ -47,9 +33,9 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
|
async def test_flash_pali_gemma_two_images(
|
||||||
chicken = get_chicken()
|
flash_pali_gemma, response_snapshot, chicken, cow_beach
|
||||||
cow_beach = get_cow_beach()
|
):
|
||||||
response = await flash_pali_gemma.generate(
|
response = await flash_pali_gemma.generate(
|
||||||
f"caption![]({chicken})![]({cow_beach})\n",
|
f"caption![]({chicken})![]({cow_beach})\n",
|
||||||
max_new_tokens=20,
|
max_new_tokens=20,
|
||||||
|
|
|
@ -25,7 +25,7 @@ async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
|
||||||
assert response.details.generated_tokens == 10
|
assert response.details.generated_tokens == 10
|
||||||
assert (
|
assert (
|
||||||
response.generated_text
|
response.generated_text
|
||||||
== "Gradient descent is a first-order optimization algorithm"
|
== "Gradient descent is an optimization algorithm commonly used in"
|
||||||
)
|
)
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
|
async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
|
||||||
response = await flash_phi35_moe.generate(
|
response = await flash_phi35_moe.generate(
|
||||||
"What is gradient descent?\n\n",
|
"What is gradient descent?\n",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
repetition_penalty=1.2,
|
repetition_penalty=1.2,
|
||||||
return_full_text=True,
|
return_full_text=True,
|
||||||
|
@ -51,7 +51,7 @@ async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
|
||||||
assert response.details.generated_tokens == 10
|
assert response.details.generated_tokens == 10
|
||||||
assert (
|
assert (
|
||||||
response.generated_text
|
response.generated_text
|
||||||
== "What is gradient descent?\n\nHello! It seems you're addressing a"
|
== "What is gradient descent?\nGradient Descent (GD) is an"
|
||||||
)
|
)
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
@ -66,7 +66,7 @@ async def test_flash_phi35_moe_load(flash_phi35_moe, generate_load, response_sna
|
||||||
assert responses[0].details.generated_tokens == 10
|
assert responses[0].details.generated_tokens == 10
|
||||||
assert (
|
assert (
|
||||||
responses[0].generated_text
|
responses[0].generated_text
|
||||||
== "Gradient descent is a first-order optimization algorithm"
|
== "Gradient descent is an optimization algorithm commonly used in"
|
||||||
)
|
)
|
||||||
assert all(
|
assert all(
|
||||||
[r.generated_text == responses[0].generated_text for r in responses]
|
[r.generated_text == responses[0].generated_text for r in responses]
|
||||||
|
|
|
@ -25,7 +25,6 @@ async def llama_grammar(llama_grammar_handle):
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
|
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
|
||||||
|
|
||||||
class Weather(BaseModel):
|
class Weather(BaseModel):
|
||||||
unit: str
|
unit: str
|
||||||
temperature: List[int]
|
temperature: List[int]
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import pytest
|
import pytest
|
||||||
import base64
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@ -16,22 +15,8 @@ async def idefics(idefics_handle):
|
||||||
return idefics_handle.client
|
return idefics_handle.client
|
||||||
|
|
||||||
|
|
||||||
# TODO fix the server parsser to count inline image tokens correctly
|
|
||||||
def get_chicken():
|
|
||||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
|
||||||
|
|
||||||
|
|
||||||
def get_cow_beach():
|
|
||||||
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_idefics(idefics, response_snapshot):
|
async def test_idefics(idefics, response_snapshot, chicken):
|
||||||
chicken = get_chicken()
|
|
||||||
response = await idefics.generate(
|
response = await idefics.generate(
|
||||||
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
|
@ -48,9 +33,7 @@ async def test_idefics(idefics, response_snapshot):
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_idefics_two_images(idefics, response_snapshot):
|
async def test_idefics_two_images(idefics, response_snapshot, chicken, cow_beach):
|
||||||
chicken = get_chicken()
|
|
||||||
cow_beach = get_cow_beach()
|
|
||||||
response = await idefics.generate(
|
response = await idefics.generate(
|
||||||
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
|
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
|
||||||
max_new_tokens=20,
|
max_new_tokens=20,
|
||||||
|
@ -63,8 +46,7 @@ async def test_idefics_two_images(idefics, response_snapshot):
|
||||||
|
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_idefics_load(idefics, generate_load, response_snapshot):
|
async def test_idefics_load(idefics, generate_load, response_snapshot, chicken):
|
||||||
chicken = get_chicken()
|
|
||||||
responses = await generate_load(
|
responses = await generate_load(
|
||||||
idefics,
|
idefics,
|
||||||
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
||||||
|
|
|
@ -1,18 +1,4 @@
|
||||||
import pytest
|
import pytest
|
||||||
import base64
|
|
||||||
|
|
||||||
|
|
||||||
# TODO fix the server parsser to count inline image tokens correctly
|
|
||||||
def get_chicken():
|
|
||||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
|
||||||
|
|
||||||
|
|
||||||
def get_cow_beach():
|
|
||||||
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@ -31,8 +17,9 @@ async def flash_idefics2_next(flash_idefics2_next_handle):
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot):
|
async def test_flash_idefics2_next_simple(
|
||||||
chicken = get_chicken()
|
flash_idefics2_next, response_snapshot, chicken
|
||||||
|
):
|
||||||
response = await flash_idefics2_next.generate(
|
response = await flash_idefics2_next.generate(
|
||||||
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
|
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
|
@ -46,9 +33,9 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot):
|
async def test_flash_idefics2_two_images(
|
||||||
chicken = get_chicken()
|
flash_idefics2_next, response_snapshot, chicken, cow_beach
|
||||||
cow_beach = get_cow_beach()
|
):
|
||||||
response = await flash_idefics2_next.generate(
|
response = await flash_idefics2_next.generate(
|
||||||
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
|
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
|
||||||
max_new_tokens=20,
|
max_new_tokens=20,
|
||||||
|
@ -87,9 +74,8 @@ async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snap
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_idefics2_next_load(
|
async def test_flash_idefics2_next_load(
|
||||||
flash_idefics2_next, generate_load, response_snapshot
|
flash_idefics2_next, generate_load, response_snapshot, chicken
|
||||||
):
|
):
|
||||||
chicken = get_chicken()
|
|
||||||
responses = await generate_load(
|
responses = await generate_load(
|
||||||
flash_idefics2_next,
|
flash_idefics2_next,
|
||||||
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
|
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
|
||||||
|
|
|
@ -1,12 +1,4 @@
|
||||||
import pytest
|
import pytest
|
||||||
import base64
|
|
||||||
|
|
||||||
|
|
||||||
# TODO fix the server parsser to count inline image tokens correctly
|
|
||||||
def get_chicken():
|
|
||||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@ -29,8 +21,7 @@ async def flash_llava_next(flash_llava_next_handle):
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
|
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot, chicken):
|
||||||
chicken = get_chicken()
|
|
||||||
response = await flash_llava_next.generate(
|
response = await flash_llava_next.generate(
|
||||||
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
||||||
max_new_tokens=10,
|
max_new_tokens=10,
|
||||||
|
@ -70,9 +61,8 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
@pytest.mark.private
|
||||||
async def test_flash_llava_next_load(
|
async def test_flash_llava_next_load(
|
||||||
flash_llava_next, generate_load, response_snapshot
|
flash_llava_next, generate_load, response_snapshot, chicken
|
||||||
):
|
):
|
||||||
chicken = get_chicken()
|
|
||||||
responses = await generate_load(
|
responses = await generate_load(
|
||||||
flash_llava_next,
|
flash_llava_next,
|
||||||
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import pytest
|
import pytest
|
||||||
import base64
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,22 +14,8 @@ async def mllama(mllama_handle):
|
||||||
return mllama_handle.client
|
return mllama_handle.client
|
||||||
|
|
||||||
|
|
||||||
# TODO fix the server parsser to count inline image tokens correctly
|
|
||||||
def get_chicken():
|
|
||||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
|
||||||
|
|
||||||
|
|
||||||
def get_cow_beach():
|
|
||||||
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read())
|
|
||||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mllama_simpl(mllama, response_snapshot):
|
async def test_mllama_simpl(mllama, response_snapshot):
|
||||||
# chicken = get_chicken()
|
|
||||||
response = await mllama.chat(
|
response = await mllama.chat(
|
||||||
max_tokens=10,
|
max_tokens=10,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
|
|
|
@ -68,7 +68,7 @@ fn get_config(
|
||||||
|
|
||||||
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
|
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
|
||||||
let compute_capability = gpu::get_cuda_capability();
|
let compute_capability = gpu::get_cuda_capability();
|
||||||
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
|
let mut prefix_caching: Option<String> = std::env::var("PREFIX_CACHING").ok();
|
||||||
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
|
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
|
||||||
if let Some(config) = config {
|
if let Some(config) = config {
|
||||||
if prefix_caching.is_none() {
|
if prefix_caching.is_none() {
|
||||||
|
@ -94,7 +94,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
||||||
prefix_caching = Some("0".to_string());
|
prefix_caching = Some("0".to_string());
|
||||||
}
|
}
|
||||||
match config.model_type.as_deref() {
|
match config.model_type.as_deref() {
|
||||||
Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
|
Some("falcon") | Some("deepseek_v2") => {
|
||||||
// Required because gemma2 needs bfloat16 which is not supported by
|
// Required because gemma2 needs bfloat16 which is not supported by
|
||||||
// flashinfer ?
|
// flashinfer ?
|
||||||
if attention.is_none() {
|
if attention.is_none() {
|
||||||
|
@ -124,6 +124,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if attention == Some("paged".to_string()) && prefix_caching.is_none() {
|
||||||
|
tracing::info!("Disabling prefix caching on paged attention");
|
||||||
|
prefix_caching = Some("0".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
let attention = attention.unwrap_or("flashinfer".to_string());
|
let attention = attention.unwrap_or("flashinfer".to_string());
|
||||||
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
|
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
|
||||||
|
@ -303,6 +307,9 @@ impl std::fmt::Display for Dtype {
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||||
enum KVCacheDtype {
|
enum KVCacheDtype {
|
||||||
|
#[clap(name = "fp8_e4m3fn")]
|
||||||
|
Fp8e4m3fn,
|
||||||
|
|
||||||
#[clap(name = "fp8_e5m2")]
|
#[clap(name = "fp8_e5m2")]
|
||||||
Fp8e5m2,
|
Fp8e5m2,
|
||||||
}
|
}
|
||||||
|
@ -310,6 +317,9 @@ enum KVCacheDtype {
|
||||||
impl std::fmt::Display for KVCacheDtype {
|
impl std::fmt::Display for KVCacheDtype {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
|
KVCacheDtype::Fp8e4m3fn => {
|
||||||
|
write!(f, "fp8_e4m3fn")
|
||||||
|
}
|
||||||
KVCacheDtype::Fp8e5m2 => {
|
KVCacheDtype::Fp8e5m2 => {
|
||||||
write!(f, "fp8_e5m2")
|
write!(f, "fp8_e5m2")
|
||||||
}
|
}
|
||||||
|
@ -420,7 +430,7 @@ struct Args {
|
||||||
|
|
||||||
/// Specify the dtype for the key-value cache. When this option is not provided,
|
/// Specify the dtype for the key-value cache. When this option is not provided,
|
||||||
/// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
|
/// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
|
||||||
/// the only supported value is `fp8_e5m2` on CUDA.
|
/// the only supported value are `fp8_e4m3fn` and `fp8_e5m2` on CUDA.
|
||||||
#[clap(long, env, value_enum)]
|
#[clap(long, env, value_enum)]
|
||||||
kv_cache_dtype: Option<KVCacheDtype>,
|
kv_cache_dtype: Option<KVCacheDtype>,
|
||||||
|
|
||||||
|
@ -1094,6 +1104,8 @@ fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1497,6 +1509,10 @@ fn spawn_webserver(
|
||||||
router_args.push(revision.to_string())
|
router_args.push(revision.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if args.trust_remote_code {
|
||||||
|
router_args.push("--trust-remote-code".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
if args.json_output {
|
if args.json_output {
|
||||||
router_args.push("--json-output".to_string());
|
router_args.push("--json-output".to_string());
|
||||||
}
|
}
|
||||||
|
@ -1678,7 +1694,7 @@ fn main() -> Result<(), LauncherError> {
|
||||||
};
|
};
|
||||||
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
|
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
|
||||||
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
|
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
|
||||||
std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
|
std::env::set_var("PREFIX_CACHING", prefix_caching);
|
||||||
std::env::set_var("ATTENTION", attention);
|
std::env::set_var("ATTENTION", attention);
|
||||||
|
|
||||||
let max_input_tokens = {
|
let max_input_tokens = {
|
||||||
|
@ -1729,12 +1745,6 @@ fn main() -> Result<(), LauncherError> {
|
||||||
"`max_input_tokens must be < `max_total_tokens`".to_string(),
|
"`max_input_tokens must be < `max_total_tokens`".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
|
||||||
return Err(LauncherError::ArgumentValidation(format!(
|
|
||||||
"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
|
|
||||||
max_batch_prefill_tokens, max_input_tokens
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
|
|
||||||
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
|
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
|
||||||
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
|
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
|
||||||
|
@ -1788,12 +1798,6 @@ fn main() -> Result<(), LauncherError> {
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
|
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
|
||||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
|
||||||
return Err(LauncherError::ArgumentValidation(format!(
|
|
||||||
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
|
||||||
max_batch_prefill_tokens, max_batch_total_tokens
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||||
return Err(LauncherError::ArgumentValidation(format!(
|
return Err(LauncherError::ArgumentValidation(format!(
|
||||||
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||||
|
|
|
@ -1,7 +1,12 @@
|
||||||
{
|
{
|
||||||
|
lib,
|
||||||
mkShell,
|
mkShell,
|
||||||
black,
|
black,
|
||||||
|
cmake,
|
||||||
isort,
|
isort,
|
||||||
|
ninja,
|
||||||
|
which,
|
||||||
|
cudaPackages,
|
||||||
openssl,
|
openssl,
|
||||||
pkg-config,
|
pkg-config,
|
||||||
protobuf,
|
protobuf,
|
||||||
|
@ -11,14 +16,17 @@
|
||||||
ruff,
|
ruff,
|
||||||
rust-bin,
|
rust-bin,
|
||||||
server,
|
server,
|
||||||
|
|
||||||
|
# Enable dependencies for building CUDA packages. Useful for e.g.
|
||||||
|
# developing marlin/moe-kernels in-place.
|
||||||
|
withCuda ? false,
|
||||||
}:
|
}:
|
||||||
|
|
||||||
mkShell {
|
mkShell {
|
||||||
buildInputs =
|
nativeBuildInputs =
|
||||||
[
|
[
|
||||||
black
|
black
|
||||||
isort
|
isort
|
||||||
openssl.dev
|
|
||||||
pkg-config
|
pkg-config
|
||||||
(rust-bin.stable.latest.default.override {
|
(rust-bin.stable.latest.default.override {
|
||||||
extensions = [
|
extensions = [
|
||||||
|
@ -31,6 +39,19 @@ mkShell {
|
||||||
redocly
|
redocly
|
||||||
ruff
|
ruff
|
||||||
]
|
]
|
||||||
|
++ (lib.optionals withCuda [
|
||||||
|
cmake
|
||||||
|
ninja
|
||||||
|
which
|
||||||
|
|
||||||
|
# For most Torch-based extensions, setting CUDA_HOME is enough, but
|
||||||
|
# some custom CMake builds (e.g. vLLM) also need to have nvcc in PATH.
|
||||||
|
cudaPackages.cuda_nvcc
|
||||||
|
]);
|
||||||
|
buildInputs =
|
||||||
|
[
|
||||||
|
openssl.dev
|
||||||
|
]
|
||||||
++ (with python3.pkgs; [
|
++ (with python3.pkgs; [
|
||||||
venvShellHook
|
venvShellHook
|
||||||
docker
|
docker
|
||||||
|
@ -40,10 +61,29 @@ mkShell {
|
||||||
pytest
|
pytest
|
||||||
pytest-asyncio
|
pytest-asyncio
|
||||||
syrupy
|
syrupy
|
||||||
]);
|
])
|
||||||
|
++ (lib.optionals withCuda (
|
||||||
|
with cudaPackages;
|
||||||
|
[
|
||||||
|
cuda_cccl
|
||||||
|
cuda_cudart
|
||||||
|
cuda_nvrtc
|
||||||
|
cuda_nvtx
|
||||||
|
cuda_profiler_api
|
||||||
|
cudnn
|
||||||
|
libcublas
|
||||||
|
libcusolver
|
||||||
|
libcusparse
|
||||||
|
]
|
||||||
|
));
|
||||||
|
|
||||||
inputsFrom = [ server ];
|
inputsFrom = [ server ];
|
||||||
|
|
||||||
|
env = lib.optionalAttrs withCuda {
|
||||||
|
CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}";
|
||||||
|
TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" python3.pkgs.torch.cudaCapabilities;
|
||||||
|
};
|
||||||
|
|
||||||
venvDir = "./.venv";
|
venvDir = "./.venv";
|
||||||
|
|
||||||
postVenvCreation = ''
|
postVenvCreation = ''
|
||||||
|
@ -51,6 +91,7 @@ mkShell {
|
||||||
( cd server ; python -m pip install --no-dependencies -e . )
|
( cd server ; python -m pip install --no-dependencies -e . )
|
||||||
( cd clients/python ; python -m pip install --no-dependencies -e . )
|
( cd clients/python ; python -m pip install --no-dependencies -e . )
|
||||||
'';
|
'';
|
||||||
|
|
||||||
postShellHook = ''
|
postShellHook = ''
|
||||||
unset SOURCE_DATE_EPOCH
|
unset SOURCE_DATE_EPOCH
|
||||||
export PATH=$PATH:~/.cargo/bin
|
export PATH=$PATH:~/.cargo/bin
|
||||||
|
|
|
@ -34,6 +34,10 @@ message InfoResponse {
|
||||||
string device_type = 3;
|
string device_type = 3;
|
||||||
optional uint32 window_size = 4;
|
optional uint32 window_size = 4;
|
||||||
uint32 speculate = 5;
|
uint32 speculate = 5;
|
||||||
|
bool support_chunking = 6;
|
||||||
|
bool use_prefix_caching = 7;
|
||||||
|
string attention_impl = 8;
|
||||||
|
uint32 block_size = 9;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Empty request
|
/// Empty request
|
||||||
|
@ -135,10 +139,14 @@ message Request {
|
||||||
repeated uint32 slots = 10;
|
repeated uint32 slots = 10;
|
||||||
/// LORA adapter index
|
/// LORA adapter index
|
||||||
optional string adapter_id = 11;
|
optional string adapter_id = 11;
|
||||||
/// Prefix length that can be retrieved from the KV cache.
|
/// Tokens that can be retrieved from the KV cache.
|
||||||
uint32 prefix_len = 12;
|
/// This value is set for the first prefill and never reset
|
||||||
|
uint32 cache_len = 12;
|
||||||
/// Context truncation
|
/// Context truncation
|
||||||
bool add_special_tokens = 13;
|
bool add_special_tokens = 13;
|
||||||
|
/// Chunk of tokens that must be computed for the first prefill
|
||||||
|
/// This value is set for the first prefill and never reset
|
||||||
|
optional uint32 chunk_len = 14;
|
||||||
}
|
}
|
||||||
|
|
||||||
message Batch {
|
message Batch {
|
||||||
|
@ -163,6 +171,8 @@ message CachedBatch {
|
||||||
uint32 size = 3;
|
uint32 size = 3;
|
||||||
/// Maximum number of tokens this batch will grow to
|
/// Maximum number of tokens this batch will grow to
|
||||||
uint32 max_tokens = 4;
|
uint32 max_tokens = 4;
|
||||||
|
/// Number of tokens in the next forward
|
||||||
|
uint32 current_tokens = 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
enum FinishReason {
|
enum FinishReason {
|
||||||
|
@ -220,6 +230,8 @@ message FilterBatchResponse {
|
||||||
message PrefillRequest {
|
message PrefillRequest {
|
||||||
/// Batch
|
/// Batch
|
||||||
Batch batch = 1;
|
Batch batch = 1;
|
||||||
|
/// Optional cached batch
|
||||||
|
CachedBatch cached_batch = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message PrefillResponse {
|
message PrefillResponse {
|
||||||
|
@ -233,6 +245,8 @@ message PrefillResponse {
|
||||||
uint64 decode_ns = 4;
|
uint64 decode_ns = 4;
|
||||||
/// Total elapsed time in nanoseconds
|
/// Total elapsed time in nanoseconds
|
||||||
uint64 total_ns = 5;
|
uint64 total_ns = 5;
|
||||||
|
/// Concatenate elapsed time in nanoseconds
|
||||||
|
optional uint64 concat_ns = 6;
|
||||||
}
|
}
|
||||||
|
|
||||||
message DecodeRequest {
|
message DecodeRequest {
|
||||||
|
|
|
@ -150,6 +150,7 @@ pub enum Config {
|
||||||
Idefics2(Idefics2),
|
Idefics2(Idefics2),
|
||||||
Ssm,
|
Ssm,
|
||||||
GptBigcode,
|
GptBigcode,
|
||||||
|
Granite,
|
||||||
Santacoder,
|
Santacoder,
|
||||||
Bloom,
|
Bloom,
|
||||||
Mpt,
|
Mpt,
|
||||||
|
|
|
@ -8,6 +8,7 @@ pub mod validation;
|
||||||
mod kserve;
|
mod kserve;
|
||||||
pub mod logging;
|
pub mod logging;
|
||||||
|
|
||||||
|
mod sagemaker;
|
||||||
pub mod usage_stats;
|
pub mod usage_stats;
|
||||||
mod vertex;
|
mod vertex;
|
||||||
|
|
||||||
|
@ -18,45 +19,6 @@ use tracing::warn;
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
use validation::Validation;
|
use validation::Validation;
|
||||||
|
|
||||||
#[derive(PartialEq)]
|
|
||||||
pub enum Attention {
|
|
||||||
Paged,
|
|
||||||
FlashDecoding,
|
|
||||||
FlashInfer,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Attention {
|
|
||||||
pub fn block_size(&self) -> u32 {
|
|
||||||
match self {
|
|
||||||
Attention::FlashDecoding => 256,
|
|
||||||
Attention::FlashInfer => 1,
|
|
||||||
Attention::Paged => 16,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct ParseError;
|
|
||||||
|
|
||||||
impl std::fmt::Display for ParseError {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "Cannot parse attention value")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
impl std::error::Error for ParseError {}
|
|
||||||
|
|
||||||
impl std::str::FromStr for Attention {
|
|
||||||
type Err = ParseError;
|
|
||||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
|
||||||
match s {
|
|
||||||
"paged" => Ok(Attention::Paged),
|
|
||||||
"flashdecoding" => Ok(Attention::FlashDecoding),
|
|
||||||
"flashinfer" => Ok(Attention::FlashInfer),
|
|
||||||
_ => Err(ParseError),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Hub type
|
/// Hub type
|
||||||
#[derive(Clone, Debug, Deserialize)]
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
pub struct HubModelInfo {
|
pub struct HubModelInfo {
|
||||||
|
|
|
@ -1,748 +0,0 @@
|
||||||
use axum::http::HeaderValue;
|
|
||||||
use clap::Parser;
|
|
||||||
use clap::Subcommand;
|
|
||||||
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
|
|
||||||
use hf_hub::{Cache, Repo, RepoType};
|
|
||||||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
|
||||||
use opentelemetry::sdk::trace;
|
|
||||||
use opentelemetry::sdk::trace::Sampler;
|
|
||||||
use opentelemetry::sdk::Resource;
|
|
||||||
use opentelemetry::{global, KeyValue};
|
|
||||||
use opentelemetry_otlp::WithExportConfig;
|
|
||||||
use std::fs::File;
|
|
||||||
use std::io::BufReader;
|
|
||||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
use text_generation_router::config::Config;
|
|
||||||
use text_generation_router::usage_stats;
|
|
||||||
use text_generation_router::{
|
|
||||||
server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
|
|
||||||
};
|
|
||||||
use thiserror::Error;
|
|
||||||
use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
|
|
||||||
use tower_http::cors::AllowOrigin;
|
|
||||||
use tracing_subscriber::layer::SubscriberExt;
|
|
||||||
use tracing_subscriber::util::SubscriberInitExt;
|
|
||||||
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
|
|
||||||
|
|
||||||
/// App Configuration
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[clap(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
#[command(subcommand)]
|
|
||||||
command: Option<Commands>,
|
|
||||||
|
|
||||||
#[clap(default_value = "128", long, env)]
|
|
||||||
max_concurrent_requests: usize,
|
|
||||||
#[clap(default_value = "2", long, env)]
|
|
||||||
max_best_of: usize,
|
|
||||||
#[clap(default_value = "4", long, env)]
|
|
||||||
max_stop_sequences: usize,
|
|
||||||
#[clap(default_value = "5", long, env)]
|
|
||||||
max_top_n_tokens: u32,
|
|
||||||
#[clap(default_value = "1024", long, env)]
|
|
||||||
max_input_tokens: usize,
|
|
||||||
#[clap(default_value = "2048", long, env)]
|
|
||||||
max_total_tokens: usize,
|
|
||||||
#[clap(default_value = "1.2", long, env)]
|
|
||||||
waiting_served_ratio: f32,
|
|
||||||
#[clap(default_value = "4096", long, env)]
|
|
||||||
max_batch_prefill_tokens: u32,
|
|
||||||
#[clap(long, env)]
|
|
||||||
max_batch_total_tokens: Option<u32>,
|
|
||||||
#[clap(default_value = "20", long, env)]
|
|
||||||
max_waiting_tokens: usize,
|
|
||||||
#[clap(long, env)]
|
|
||||||
max_batch_size: Option<usize>,
|
|
||||||
#[clap(default_value = "0.0.0.0", long, env)]
|
|
||||||
hostname: String,
|
|
||||||
#[clap(default_value = "3000", long, short, env)]
|
|
||||||
port: u16,
|
|
||||||
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
|
|
||||||
master_shard_uds_path: String,
|
|
||||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
|
||||||
tokenizer_name: String,
|
|
||||||
#[clap(long, env)]
|
|
||||||
tokenizer_config_path: Option<String>,
|
|
||||||
#[clap(long, env)]
|
|
||||||
revision: Option<String>,
|
|
||||||
#[clap(default_value = "2", long, env)]
|
|
||||||
validation_workers: usize,
|
|
||||||
#[clap(long, env)]
|
|
||||||
json_output: bool,
|
|
||||||
#[clap(long, env)]
|
|
||||||
otlp_endpoint: Option<String>,
|
|
||||||
#[clap(default_value = "text-generation-inference.router", long, env)]
|
|
||||||
otlp_service_name: String,
|
|
||||||
#[clap(long, env)]
|
|
||||||
cors_allow_origin: Option<Vec<String>>,
|
|
||||||
#[clap(long, env)]
|
|
||||||
api_key: Option<String>,
|
|
||||||
#[clap(long, env)]
|
|
||||||
ngrok: bool,
|
|
||||||
#[clap(long, env)]
|
|
||||||
ngrok_authtoken: Option<String>,
|
|
||||||
#[clap(long, env)]
|
|
||||||
ngrok_edge: Option<String>,
|
|
||||||
#[clap(long, env, default_value_t = false)]
|
|
||||||
messages_api_enabled: bool,
|
|
||||||
#[clap(long, env, default_value_t = false)]
|
|
||||||
disable_grammar_support: bool,
|
|
||||||
#[clap(default_value = "4", long, env)]
|
|
||||||
max_client_batch_size: usize,
|
|
||||||
#[clap(long, env, default_value_t)]
|
|
||||||
disable_usage_stats: bool,
|
|
||||||
#[clap(long, env, default_value_t)]
|
|
||||||
disable_crash_reports: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Subcommand)]
|
|
||||||
enum Commands {
|
|
||||||
PrintSchema,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::main]
|
|
||||||
async fn main() -> Result<(), RouterError> {
|
|
||||||
let args = Args::parse();
|
|
||||||
|
|
||||||
// Pattern match configuration
|
|
||||||
let Args {
|
|
||||||
max_concurrent_requests,
|
|
||||||
max_best_of,
|
|
||||||
max_stop_sequences,
|
|
||||||
max_top_n_tokens,
|
|
||||||
max_input_tokens,
|
|
||||||
max_total_tokens,
|
|
||||||
waiting_served_ratio,
|
|
||||||
max_batch_prefill_tokens,
|
|
||||||
max_batch_total_tokens,
|
|
||||||
max_waiting_tokens,
|
|
||||||
max_batch_size,
|
|
||||||
hostname,
|
|
||||||
port,
|
|
||||||
master_shard_uds_path,
|
|
||||||
tokenizer_name,
|
|
||||||
tokenizer_config_path,
|
|
||||||
revision,
|
|
||||||
validation_workers,
|
|
||||||
json_output,
|
|
||||||
otlp_endpoint,
|
|
||||||
otlp_service_name,
|
|
||||||
cors_allow_origin,
|
|
||||||
api_key,
|
|
||||||
ngrok,
|
|
||||||
ngrok_authtoken,
|
|
||||||
ngrok_edge,
|
|
||||||
messages_api_enabled,
|
|
||||||
disable_grammar_support,
|
|
||||||
max_client_batch_size,
|
|
||||||
disable_usage_stats,
|
|
||||||
disable_crash_reports,
|
|
||||||
command,
|
|
||||||
} = args;
|
|
||||||
|
|
||||||
let print_schema_command = match command {
|
|
||||||
Some(Commands::PrintSchema) => true,
|
|
||||||
None => {
|
|
||||||
// only init logging if we are not running the print schema command
|
|
||||||
init_logging(otlp_endpoint, otlp_service_name, json_output);
|
|
||||||
false
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Validate args
|
|
||||||
if max_input_tokens >= max_total_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(
|
|
||||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
|
||||||
}
|
|
||||||
|
|
||||||
if validation_workers == 0 {
|
|
||||||
return Err(RouterError::ArgumentValidation(
|
|
||||||
"`validation_workers` must be > 0".to_string(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
|
||||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
|
||||||
}
|
|
||||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
|
||||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CORS allowed origins
|
|
||||||
// map to go inside the option and then map to parse from String to HeaderValue
|
|
||||||
// Finally, convert to AllowOrigin
|
|
||||||
let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
|
|
||||||
AllowOrigin::list(
|
|
||||||
cors_allow_origin
|
|
||||||
.iter()
|
|
||||||
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
|
|
||||||
)
|
|
||||||
});
|
|
||||||
|
|
||||||
// Parse Huggingface hub token
|
|
||||||
let authorization_token = std::env::var("HF_TOKEN")
|
|
||||||
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
|
|
||||||
.ok();
|
|
||||||
|
|
||||||
// Tokenizer instance
|
|
||||||
// This will only be used to validate payloads
|
|
||||||
let local_path = Path::new(&tokenizer_name);
|
|
||||||
|
|
||||||
// Shared API builder initialization
|
|
||||||
let api_builder = || {
|
|
||||||
let mut builder = ApiBuilder::new()
|
|
||||||
.with_progress(false)
|
|
||||||
.with_token(authorization_token);
|
|
||||||
|
|
||||||
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
|
|
||||||
builder = builder.with_cache_dir(cache_dir.into());
|
|
||||||
}
|
|
||||||
|
|
||||||
builder
|
|
||||||
};
|
|
||||||
|
|
||||||
// Decide if we need to use the API based on the revision and local path
|
|
||||||
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
|
|
||||||
|
|
||||||
// Initialize API if needed
|
|
||||||
#[derive(Clone)]
|
|
||||||
enum Type {
|
|
||||||
Api(Api),
|
|
||||||
Cache(Cache),
|
|
||||||
None,
|
|
||||||
}
|
|
||||||
let api = if use_api {
|
|
||||||
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
|
|
||||||
let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
|
|
||||||
.map_err(|_| ())
|
|
||||||
.map(|cache_dir| Cache::new(cache_dir.into()))
|
|
||||||
.unwrap_or_else(|_| Cache::default());
|
|
||||||
|
|
||||||
tracing::warn!("Offline mode active using cache defaults");
|
|
||||||
Type::Cache(cache)
|
|
||||||
} else {
|
|
||||||
tracing::info!("Using the Hugging Face API");
|
|
||||||
match api_builder().build() {
|
|
||||||
Ok(api) => Type::Api(api),
|
|
||||||
Err(_) => {
|
|
||||||
tracing::warn!("Unable to build the Hugging Face API");
|
|
||||||
Type::None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Type::None
|
|
||||||
};
|
|
||||||
|
|
||||||
// Load tokenizer and model info
|
|
||||||
let (
|
|
||||||
tokenizer_filename,
|
|
||||||
config_filename,
|
|
||||||
tokenizer_config_filename,
|
|
||||||
preprocessor_config_filename,
|
|
||||||
processor_config_filename,
|
|
||||||
model_info,
|
|
||||||
) = match api {
|
|
||||||
Type::None => (
|
|
||||||
Some(local_path.join("tokenizer.json")),
|
|
||||||
Some(local_path.join("config.json")),
|
|
||||||
Some(local_path.join("tokenizer_config.json")),
|
|
||||||
Some(local_path.join("preprocessor_config.json")),
|
|
||||||
Some(local_path.join("processor_config.json")),
|
|
||||||
None,
|
|
||||||
),
|
|
||||||
Type::Api(api) => {
|
|
||||||
let api_repo = api.repo(Repo::with_revision(
|
|
||||||
tokenizer_name.to_string(),
|
|
||||||
RepoType::Model,
|
|
||||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
|
||||||
));
|
|
||||||
|
|
||||||
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
|
|
||||||
Ok(tokenizer_filename) => Some(tokenizer_filename),
|
|
||||||
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
|
||||||
};
|
|
||||||
let config_filename = api_repo.get("config.json").await.ok();
|
|
||||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
|
||||||
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
|
|
||||||
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
|
|
||||||
|
|
||||||
let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
|
|
||||||
Some(model_info)
|
|
||||||
} else {
|
|
||||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
|
||||||
None
|
|
||||||
};
|
|
||||||
(
|
|
||||||
tokenizer_filename,
|
|
||||||
config_filename,
|
|
||||||
tokenizer_config_filename,
|
|
||||||
preprocessor_config_filename,
|
|
||||||
processor_config_filename,
|
|
||||||
model_info,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
Type::Cache(cache) => {
|
|
||||||
let repo = cache.repo(Repo::with_revision(
|
|
||||||
tokenizer_name.to_string(),
|
|
||||||
RepoType::Model,
|
|
||||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
|
||||||
));
|
|
||||||
(
|
|
||||||
repo.get("tokenizer.json"),
|
|
||||||
repo.get("config.json"),
|
|
||||||
repo.get("tokenizer_config.json"),
|
|
||||||
repo.get("preprocessor_config.json"),
|
|
||||||
repo.get("processor_config.json"),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let config: Option<Config> = config_filename.and_then(|filename| {
|
|
||||||
std::fs::read_to_string(filename)
|
|
||||||
.ok()
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|c| {
|
|
||||||
let config: Result<Config, _> = serde_json::from_str(c);
|
|
||||||
if let Err(err) = &config {
|
|
||||||
tracing::warn!("Could not parse config {err:?}");
|
|
||||||
}
|
|
||||||
config.ok()
|
|
||||||
})
|
|
||||||
});
|
|
||||||
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
|
|
||||||
model_id: tokenizer_name.to_string(),
|
|
||||||
sha: None,
|
|
||||||
pipeline_tag: None,
|
|
||||||
});
|
|
||||||
|
|
||||||
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
|
||||||
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
|
|
||||||
{
|
|
||||||
HubTokenizerConfig::from_file(filename)
|
|
||||||
} else {
|
|
||||||
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
|
|
||||||
};
|
|
||||||
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
|
|
||||||
tracing::warn!("Could not find tokenizer config locally and no API specified");
|
|
||||||
HubTokenizerConfig::default()
|
|
||||||
});
|
|
||||||
let tokenizer_class = tokenizer_config.tokenizer_class.clone();
|
|
||||||
|
|
||||||
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
|
|
||||||
let mut tokenizer = Tokenizer::from_file(filename).ok();
|
|
||||||
if let Some(tokenizer) = &mut tokenizer {
|
|
||||||
if let Some(class) = &tokenizer_config.tokenizer_class {
|
|
||||||
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
|
|
||||||
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
|
|
||||||
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
|
|
||||||
tokenizer.with_post_processor(post_processor);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tokenizer
|
|
||||||
});
|
|
||||||
|
|
||||||
let preprocessor_config =
|
|
||||||
preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
|
|
||||||
let processor_config = processor_config_filename
|
|
||||||
.and_then(HubProcessorConfig::from_file)
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
tracing::info!("Using config {config:?}");
|
|
||||||
if tokenizer.is_none() {
|
|
||||||
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
|
|
||||||
tracing::warn!("Rust input length validation and truncation is disabled");
|
|
||||||
}
|
|
||||||
|
|
||||||
// if pipeline-tag == text-generation we default to return_full_text = true
|
|
||||||
let compat_return_full_text = match &model_info.pipeline_tag {
|
|
||||||
None => {
|
|
||||||
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
|
|
||||||
true
|
|
||||||
}
|
|
||||||
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
|
|
||||||
};
|
|
||||||
|
|
||||||
// Determine the server port based on the feature and environment variable.
|
|
||||||
let port = if cfg!(feature = "google") {
|
|
||||||
std::env::var("AIP_HTTP_PORT")
|
|
||||||
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
|
|
||||||
.unwrap_or(port)
|
|
||||||
} else {
|
|
||||||
port
|
|
||||||
};
|
|
||||||
|
|
||||||
let addr = match hostname.parse() {
|
|
||||||
Ok(ip) => SocketAddr::new(ip, port),
|
|
||||||
Err(_) => {
|
|
||||||
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
|
|
||||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Only send usage stats when TGI is run in container and the function returns Some
|
|
||||||
let is_container = matches!(usage_stats::is_container(), Ok(true));
|
|
||||||
|
|
||||||
let user_agent = if !disable_usage_stats && is_container {
|
|
||||||
let reduced_args = usage_stats::Args::new(
|
|
||||||
config.clone(),
|
|
||||||
tokenizer_class,
|
|
||||||
max_concurrent_requests,
|
|
||||||
max_best_of,
|
|
||||||
max_stop_sequences,
|
|
||||||
max_top_n_tokens,
|
|
||||||
max_input_tokens,
|
|
||||||
max_total_tokens,
|
|
||||||
waiting_served_ratio,
|
|
||||||
max_batch_prefill_tokens,
|
|
||||||
max_batch_total_tokens,
|
|
||||||
max_waiting_tokens,
|
|
||||||
max_batch_size,
|
|
||||||
revision,
|
|
||||||
validation_workers,
|
|
||||||
messages_api_enabled,
|
|
||||||
disable_grammar_support,
|
|
||||||
max_client_batch_size,
|
|
||||||
disable_usage_stats,
|
|
||||||
disable_crash_reports,
|
|
||||||
);
|
|
||||||
Some(usage_stats::UserAgent::new(reduced_args))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(ref ua) = user_agent {
|
|
||||||
let start_event =
|
|
||||||
usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None);
|
|
||||||
tokio::spawn(async move {
|
|
||||||
start_event.send().await;
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
// Run server
|
|
||||||
let result = server::run(
|
|
||||||
master_shard_uds_path,
|
|
||||||
model_info,
|
|
||||||
compat_return_full_text,
|
|
||||||
max_concurrent_requests,
|
|
||||||
max_best_of,
|
|
||||||
max_stop_sequences,
|
|
||||||
max_top_n_tokens,
|
|
||||||
max_input_tokens,
|
|
||||||
max_total_tokens,
|
|
||||||
waiting_served_ratio,
|
|
||||||
max_batch_prefill_tokens,
|
|
||||||
max_batch_total_tokens,
|
|
||||||
max_waiting_tokens,
|
|
||||||
max_batch_size,
|
|
||||||
tokenizer,
|
|
||||||
config,
|
|
||||||
validation_workers,
|
|
||||||
addr,
|
|
||||||
cors_allow_origin,
|
|
||||||
api_key,
|
|
||||||
ngrok,
|
|
||||||
ngrok_authtoken,
|
|
||||||
ngrok_edge,
|
|
||||||
tokenizer_config,
|
|
||||||
preprocessor_config,
|
|
||||||
processor_config,
|
|
||||||
messages_api_enabled,
|
|
||||||
disable_grammar_support,
|
|
||||||
max_client_batch_size,
|
|
||||||
print_schema_command,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(_) => {
|
|
||||||
if let Some(ref ua) = user_agent {
|
|
||||||
let stop_event = usage_stats::UsageStatsEvent::new(
|
|
||||||
ua.clone(),
|
|
||||||
usage_stats::EventType::Stop,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
stop_event.send().await;
|
|
||||||
};
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
if let Some(ref ua) = user_agent {
|
|
||||||
if !disable_crash_reports {
|
|
||||||
let error_event = usage_stats::UsageStatsEvent::new(
|
|
||||||
ua.clone(),
|
|
||||||
usage_stats::EventType::Error,
|
|
||||||
Some(e.to_string()),
|
|
||||||
);
|
|
||||||
error_event.send().await;
|
|
||||||
} else {
|
|
||||||
let unknow_error_event = usage_stats::UsageStatsEvent::new(
|
|
||||||
ua.clone(),
|
|
||||||
usage_stats::EventType::Error,
|
|
||||||
Some("unknow_error".to_string()),
|
|
||||||
);
|
|
||||||
unknow_error_event.send().await;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Err(RouterError::WebServer(e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
|
|
||||||
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
|
|
||||||
/// - otlp_service_name service name to appear in APM
|
|
||||||
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
|
|
||||||
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
|
|
||||||
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
|
|
||||||
fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
|
|
||||||
let mut layers = Vec::new();
|
|
||||||
|
|
||||||
// STDOUT/STDERR layer
|
|
||||||
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
|
|
||||||
let fmt_layer = tracing_subscriber::fmt::layer()
|
|
||||||
.with_file(true)
|
|
||||||
.with_ansi(ansi)
|
|
||||||
.with_line_number(true);
|
|
||||||
|
|
||||||
let fmt_layer = match json_output {
|
|
||||||
true => fmt_layer.json().flatten_event(true).boxed(),
|
|
||||||
false => fmt_layer.boxed(),
|
|
||||||
};
|
|
||||||
layers.push(fmt_layer);
|
|
||||||
|
|
||||||
// OpenTelemetry tracing layer
|
|
||||||
if let Some(otlp_endpoint) = otlp_endpoint {
|
|
||||||
global::set_text_map_propagator(TraceContextPropagator::new());
|
|
||||||
|
|
||||||
let tracer = opentelemetry_otlp::new_pipeline()
|
|
||||||
.tracing()
|
|
||||||
.with_exporter(
|
|
||||||
opentelemetry_otlp::new_exporter()
|
|
||||||
.tonic()
|
|
||||||
.with_endpoint(otlp_endpoint),
|
|
||||||
)
|
|
||||||
.with_trace_config(
|
|
||||||
trace::config()
|
|
||||||
.with_resource(Resource::new(vec![KeyValue::new(
|
|
||||||
"service.name",
|
|
||||||
otlp_service_name,
|
|
||||||
)]))
|
|
||||||
.with_sampler(Sampler::AlwaysOn),
|
|
||||||
)
|
|
||||||
.install_batch(opentelemetry::runtime::Tokio);
|
|
||||||
|
|
||||||
if let Ok(tracer) = tracer {
|
|
||||||
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
|
|
||||||
init_tracing_opentelemetry::init_propagator().unwrap();
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter events with LOG_LEVEL
|
|
||||||
let varname = "LOG_LEVEL";
|
|
||||||
let env_filter = if let Ok(log_level) = std::env::var(varname) {
|
|
||||||
// Override to avoid simple logs to be spammed with tokio level informations
|
|
||||||
let log_level = match &log_level[..] {
|
|
||||||
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
|
|
||||||
"info" => "text_generation_launcher=info,text_generation_router=info",
|
|
||||||
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
|
|
||||||
log_level => log_level,
|
|
||||||
};
|
|
||||||
EnvFilter::builder()
|
|
||||||
.with_default_directive(LevelFilter::INFO.into())
|
|
||||||
.parse_lossy(log_level)
|
|
||||||
} else {
|
|
||||||
EnvFilter::new("info")
|
|
||||||
};
|
|
||||||
|
|
||||||
tracing_subscriber::registry()
|
|
||||||
.with(env_filter)
|
|
||||||
.with(layers)
|
|
||||||
.init();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// get model info from the Huggingface Hub
|
|
||||||
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
|
|
||||||
let response = api.info_request().send().await.ok()?;
|
|
||||||
|
|
||||||
if response.status().is_success() {
|
|
||||||
let hub_model_info: HubModelInfo =
|
|
||||||
serde_json::from_str(&response.text().await.ok()?).ok()?;
|
|
||||||
if let Some(sha) = &hub_model_info.sha {
|
|
||||||
tracing::info!(
|
|
||||||
"Serving revision {sha} of model {}",
|
|
||||||
hub_model_info.model_id
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Some(hub_model_info)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// get base tokenizer
|
|
||||||
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
|
|
||||||
let config_filename = api_repo.get("config.json").await.ok()?;
|
|
||||||
|
|
||||||
// Open the file in read-only mode with buffer.
|
|
||||||
let file = File::open(config_filename).ok()?;
|
|
||||||
let reader = BufReader::new(file);
|
|
||||||
|
|
||||||
// Read the JSON contents of the file as an instance of `User`.
|
|
||||||
let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
|
|
||||||
|
|
||||||
if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
|
|
||||||
let api_base_repo = api.repo(Repo::with_revision(
|
|
||||||
base_model_id.to_string(),
|
|
||||||
RepoType::Model,
|
|
||||||
"main".to_string(),
|
|
||||||
));
|
|
||||||
|
|
||||||
api_base_repo.get("tokenizer.json").await.ok()
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// get tokenizer_config from the Huggingface Hub
|
|
||||||
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
|
|
||||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
|
|
||||||
|
|
||||||
// Open the file in read-only mode with buffer.
|
|
||||||
let file = File::open(tokenizer_config_filename).ok()?;
|
|
||||||
let reader = BufReader::new(file);
|
|
||||||
|
|
||||||
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
|
||||||
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
|
|
||||||
.map_err(|e| {
|
|
||||||
tracing::warn!("Unable to parse tokenizer config: {}", e);
|
|
||||||
e
|
|
||||||
})
|
|
||||||
.ok()?;
|
|
||||||
|
|
||||||
Some(tokenizer_config)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a post_processor for the LlamaTokenizer
|
|
||||||
pub fn create_post_processor(
|
|
||||||
tokenizer: &Tokenizer,
|
|
||||||
tokenizer_config: &HubTokenizerConfig,
|
|
||||||
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
|
|
||||||
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
|
|
||||||
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
|
|
||||||
|
|
||||||
let bos_token = tokenizer_config.bos_token.as_ref();
|
|
||||||
let eos_token = tokenizer_config.eos_token.as_ref();
|
|
||||||
|
|
||||||
if add_bos_token && bos_token.is_none() {
|
|
||||||
panic!("add_bos_token = true but bos_token is None");
|
|
||||||
}
|
|
||||||
|
|
||||||
if add_eos_token && eos_token.is_none() {
|
|
||||||
panic!("add_eos_token = true but eos_token is None");
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut single = Vec::new();
|
|
||||||
let mut pair = Vec::new();
|
|
||||||
let mut special_tokens = Vec::new();
|
|
||||||
|
|
||||||
if add_bos_token {
|
|
||||||
if let Some(bos) = bos_token {
|
|
||||||
let bos_token_id = tokenizer
|
|
||||||
.token_to_id(bos.as_str())
|
|
||||||
.expect("Should have found the bos token id");
|
|
||||||
special_tokens.push((bos.as_str(), bos_token_id));
|
|
||||||
single.push(format!("{}:0", bos.as_str()));
|
|
||||||
pair.push(format!("{}:0", bos.as_str()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
single.push("$A:0".to_string());
|
|
||||||
pair.push("$A:0".to_string());
|
|
||||||
|
|
||||||
if add_eos_token {
|
|
||||||
if let Some(eos) = eos_token {
|
|
||||||
let eos_token_id = tokenizer
|
|
||||||
.token_to_id(eos.as_str())
|
|
||||||
.expect("Should have found the eos token id");
|
|
||||||
special_tokens.push((eos.as_str(), eos_token_id));
|
|
||||||
single.push(format!("{}:0", eos.as_str()));
|
|
||||||
pair.push(format!("{}:0", eos.as_str()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if add_bos_token {
|
|
||||||
if let Some(bos) = bos_token {
|
|
||||||
pair.push(format!("{}:1", bos.as_str()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pair.push("$B:1".to_string());
|
|
||||||
|
|
||||||
if add_eos_token {
|
|
||||||
if let Some(eos) = eos_token {
|
|
||||||
pair.push(format!("{}:1", eos.as_str()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let post_processor = TemplateProcessing::builder()
|
|
||||||
.try_single(single)?
|
|
||||||
.try_pair(pair)?
|
|
||||||
.special_tokens(special_tokens)
|
|
||||||
.build()?;
|
|
||||||
|
|
||||||
Ok(post_processor)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
enum RouterError {
|
|
||||||
#[error("Argument validation error: {0}")]
|
|
||||||
ArgumentValidation(String),
|
|
||||||
#[error("WebServer error: {0}")]
|
|
||||||
WebServer(#[from] server::WebServerError),
|
|
||||||
#[error("Tokio runtime failed to start: {0}")]
|
|
||||||
Tokio(#[from] std::io::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use text_generation_router::TokenizerConfigToken;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_create_post_processor() {
|
|
||||||
let tokenizer_config = HubTokenizerConfig {
|
|
||||||
add_bos_token: None,
|
|
||||||
add_eos_token: None,
|
|
||||||
bos_token: Some(TokenizerConfigToken::String("<s>".to_string())),
|
|
||||||
eos_token: Some(TokenizerConfigToken::String("</s>".to_string())),
|
|
||||||
chat_template: None,
|
|
||||||
tokenizer_class: None,
|
|
||||||
completion_template: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let tokenizer =
|
|
||||||
Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
|
|
||||||
let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
|
|
||||||
|
|
||||||
let expected = TemplateProcessing::builder()
|
|
||||||
.try_single("<s>:0 $A:0")
|
|
||||||
.unwrap()
|
|
||||||
.try_pair("<s>:0 $A:0 <s>:1 $B:1")
|
|
||||||
.unwrap()
|
|
||||||
.special_tokens(vec![("<s>".to_string(), 1)])
|
|
||||||
.build()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(post_processor, expected);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,82 @@
|
||||||
|
use crate::infer::Infer;
|
||||||
|
use crate::server::{chat_completions, compat_generate, completions, ComputeType};
|
||||||
|
use crate::{
|
||||||
|
ChatCompletion, ChatCompletionChunk, ChatRequest, Chunk, CompatGenerateRequest,
|
||||||
|
CompletionFinal, CompletionRequest, ErrorResponse, GenerateResponse, Info, StreamResponse,
|
||||||
|
};
|
||||||
|
use axum::extract::Extension;
|
||||||
|
use axum::http::StatusCode;
|
||||||
|
use axum::response::Response;
|
||||||
|
use axum::Json;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tracing::instrument;
|
||||||
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, ToSchema)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub(crate) enum SagemakerRequest {
|
||||||
|
Generate(CompatGenerateRequest),
|
||||||
|
Chat(ChatRequest),
|
||||||
|
Completion(CompletionRequest),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Used for OpenAPI specs
|
||||||
|
#[allow(dead_code)]
|
||||||
|
#[derive(Serialize, ToSchema)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub(crate) enum SagemakerResponse {
|
||||||
|
Generate(GenerateResponse),
|
||||||
|
Chat(ChatCompletion),
|
||||||
|
Completion(CompletionFinal),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Used for OpenAPI specs
|
||||||
|
#[allow(dead_code)]
|
||||||
|
#[derive(Serialize, ToSchema)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub(crate) enum SagemakerStreamResponse {
|
||||||
|
Generate(StreamResponse),
|
||||||
|
Chat(ChatCompletionChunk),
|
||||||
|
Completion(Chunk),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate tokens from Sagemaker request
|
||||||
|
#[utoipa::path(
|
||||||
|
post,
|
||||||
|
tag = "Text Generation Inference",
|
||||||
|
path = "/invocations",
|
||||||
|
request_body = SagemakerRequest,
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Generated Chat Completion",
|
||||||
|
content(
|
||||||
|
("application/json" = SagemakerResponse),
|
||||||
|
("text/event-stream" = SagemakerStreamResponse),
|
||||||
|
)),
|
||||||
|
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Request failed during generation", "error_type": "generation"})),
|
||||||
|
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})),
|
||||||
|
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Input validation error", "error_type": "validation"})),
|
||||||
|
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Incomplete generation", "error_type": "incomplete_generation"})),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
pub(crate) async fn sagemaker_compatibility(
|
||||||
|
default_return_full_text: Extension<bool>,
|
||||||
|
infer: Extension<Infer>,
|
||||||
|
compute_type: Extension<ComputeType>,
|
||||||
|
info: Extension<Info>,
|
||||||
|
Json(req): Json<SagemakerRequest>,
|
||||||
|
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
match req {
|
||||||
|
SagemakerRequest::Generate(req) => {
|
||||||
|
compat_generate(default_return_full_text, infer, compute_type, Json(req)).await
|
||||||
|
}
|
||||||
|
SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, Json(req)).await,
|
||||||
|
SagemakerRequest::Completion(req) => {
|
||||||
|
completions(infer, compute_type, info, Json(req)).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,6 +8,10 @@ use crate::kserve::{
|
||||||
kserve_model_metadata, kserve_model_metadata_ready,
|
kserve_model_metadata, kserve_model_metadata_ready,
|
||||||
};
|
};
|
||||||
use crate::logging::trace_context_middleware;
|
use crate::logging::trace_context_middleware;
|
||||||
|
use crate::sagemaker::{
|
||||||
|
sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse,
|
||||||
|
__path_sagemaker_compatibility,
|
||||||
|
};
|
||||||
use crate::validation::ValidationError;
|
use crate::validation::ValidationError;
|
||||||
use crate::vertex::vertex_compatibility;
|
use crate::vertex::vertex_compatibility;
|
||||||
use crate::ChatTokenizeResponse;
|
use crate::ChatTokenizeResponse;
|
||||||
|
@ -85,7 +89,7 @@ example = json ! ({"error": "Incomplete generation"})),
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
#[instrument(skip(infer, req))]
|
#[instrument(skip(infer, req))]
|
||||||
async fn compat_generate(
|
pub(crate) async fn compat_generate(
|
||||||
Extension(default_return_full_text): Extension<bool>,
|
Extension(default_return_full_text): Extension<bool>,
|
||||||
infer: Extension<Infer>,
|
infer: Extension<Infer>,
|
||||||
compute_type: Extension<ComputeType>,
|
compute_type: Extension<ComputeType>,
|
||||||
|
@ -694,7 +698,7 @@ time_per_token,
|
||||||
seed,
|
seed,
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
async fn completions(
|
pub(crate) async fn completions(
|
||||||
Extension(infer): Extension<Infer>,
|
Extension(infer): Extension<Infer>,
|
||||||
Extension(compute_type): Extension<ComputeType>,
|
Extension(compute_type): Extension<ComputeType>,
|
||||||
Extension(info): Extension<Info>,
|
Extension(info): Extension<Info>,
|
||||||
|
@ -1223,7 +1227,7 @@ time_per_token,
|
||||||
seed,
|
seed,
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
async fn chat_completions(
|
pub(crate) async fn chat_completions(
|
||||||
Extension(infer): Extension<Infer>,
|
Extension(infer): Extension<Infer>,
|
||||||
Extension(compute_type): Extension<ComputeType>,
|
Extension(compute_type): Extension<ComputeType>,
|
||||||
Extension(info): Extension<Info>,
|
Extension(info): Extension<Info>,
|
||||||
|
@ -1539,11 +1543,13 @@ completions,
|
||||||
tokenize,
|
tokenize,
|
||||||
metrics,
|
metrics,
|
||||||
openai_get_model_info,
|
openai_get_model_info,
|
||||||
|
sagemaker_compatibility,
|
||||||
),
|
),
|
||||||
components(
|
components(
|
||||||
schemas(
|
schemas(
|
||||||
Info,
|
Info,
|
||||||
CompatGenerateRequest,
|
CompatGenerateRequest,
|
||||||
|
SagemakerRequest,
|
||||||
GenerateRequest,
|
GenerateRequest,
|
||||||
GrammarType,
|
GrammarType,
|
||||||
ChatRequest,
|
ChatRequest,
|
||||||
|
@ -1566,6 +1572,8 @@ ChatCompletionTopLogprob,
|
||||||
ChatCompletion,
|
ChatCompletion,
|
||||||
CompletionRequest,
|
CompletionRequest,
|
||||||
CompletionComplete,
|
CompletionComplete,
|
||||||
|
SagemakerResponse,
|
||||||
|
SagemakerStreamResponse,
|
||||||
Chunk,
|
Chunk,
|
||||||
Completion,
|
Completion,
|
||||||
CompletionFinal,
|
CompletionFinal,
|
||||||
|
@ -1627,13 +1635,13 @@ pub async fn run(
|
||||||
tokenizer_name: String,
|
tokenizer_name: String,
|
||||||
tokenizer_config_path: Option<String>,
|
tokenizer_config_path: Option<String>,
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
|
trust_remote_code: bool,
|
||||||
hostname: String,
|
hostname: String,
|
||||||
port: u16,
|
port: u16,
|
||||||
cors_allow_origin: Option<Vec<String>>,
|
cors_allow_origin: Option<Vec<String>>,
|
||||||
ngrok: bool,
|
ngrok: bool,
|
||||||
_ngrok_authtoken: Option<String>,
|
_ngrok_authtoken: Option<String>,
|
||||||
_ngrok_edge: Option<String>,
|
_ngrok_edge: Option<String>,
|
||||||
messages_api_enabled: bool,
|
|
||||||
disable_grammar_support: bool,
|
disable_grammar_support: bool,
|
||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
usage_stats_level: usage_stats::UsageStatsLevel,
|
usage_stats_level: usage_stats::UsageStatsLevel,
|
||||||
|
@ -1787,10 +1795,13 @@ pub async fn run(
|
||||||
let auto = transformers.getattr("AutoTokenizer")?;
|
let auto = transformers.getattr("AutoTokenizer")?;
|
||||||
let from_pretrained = auto.getattr("from_pretrained")?;
|
let from_pretrained = auto.getattr("from_pretrained")?;
|
||||||
let args = (tokenizer_name.to_string(),);
|
let args = (tokenizer_name.to_string(),);
|
||||||
let kwargs = [(
|
let kwargs = [
|
||||||
|
(
|
||||||
"revision",
|
"revision",
|
||||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
(revision.clone().unwrap_or_else(|| "main".to_string())).into_py(py),
|
||||||
)]
|
),
|
||||||
|
("trust_remote_code", trust_remote_code.into_py(py)),
|
||||||
|
]
|
||||||
.into_py_dict_bound(py);
|
.into_py_dict_bound(py);
|
||||||
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
|
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
|
||||||
let save = tokenizer.getattr("save_pretrained")?;
|
let save = tokenizer.getattr("save_pretrained")?;
|
||||||
|
@ -1862,7 +1873,6 @@ pub async fn run(
|
||||||
// max_batch_size,
|
// max_batch_size,
|
||||||
revision.clone(),
|
revision.clone(),
|
||||||
validation_workers,
|
validation_workers,
|
||||||
messages_api_enabled,
|
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
usage_stats_level,
|
usage_stats_level,
|
||||||
|
@ -1904,7 +1914,6 @@ pub async fn run(
|
||||||
ngrok,
|
ngrok,
|
||||||
_ngrok_authtoken,
|
_ngrok_authtoken,
|
||||||
_ngrok_edge,
|
_ngrok_edge,
|
||||||
messages_api_enabled,
|
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
model_info,
|
model_info,
|
||||||
|
@ -1964,7 +1973,6 @@ async fn start(
|
||||||
ngrok: bool,
|
ngrok: bool,
|
||||||
_ngrok_authtoken: Option<String>,
|
_ngrok_authtoken: Option<String>,
|
||||||
_ngrok_edge: Option<String>,
|
_ngrok_edge: Option<String>,
|
||||||
messages_api_enabled: bool,
|
|
||||||
disable_grammar_support: bool,
|
disable_grammar_support: bool,
|
||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
model_info: HubModelInfo,
|
model_info: HubModelInfo,
|
||||||
|
@ -2279,6 +2287,7 @@ async fn start(
|
||||||
.route("/v1/chat/completions", post(chat_completions))
|
.route("/v1/chat/completions", post(chat_completions))
|
||||||
.route("/v1/completions", post(completions))
|
.route("/v1/completions", post(completions))
|
||||||
.route("/vertex", post(vertex_compatibility))
|
.route("/vertex", post(vertex_compatibility))
|
||||||
|
.route("/invocations", post(sagemaker_compatibility))
|
||||||
.route("/tokenize", post(tokenize));
|
.route("/tokenize", post(tokenize));
|
||||||
|
|
||||||
if let Some(api_key) = api_key {
|
if let Some(api_key) = api_key {
|
||||||
|
@ -2314,13 +2323,6 @@ async fn start(
|
||||||
.route("/metrics", get(metrics))
|
.route("/metrics", get(metrics))
|
||||||
.route("/v1/models", get(openai_get_model_info));
|
.route("/v1/models", get(openai_get_model_info));
|
||||||
|
|
||||||
// Conditional AWS Sagemaker route
|
|
||||||
let aws_sagemaker_route = if messages_api_enabled {
|
|
||||||
Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
|
|
||||||
} else {
|
|
||||||
Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
|
|
||||||
};
|
|
||||||
|
|
||||||
let compute_type =
|
let compute_type =
|
||||||
ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
|
ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
|
||||||
|
|
||||||
|
@ -2328,8 +2330,7 @@ async fn start(
|
||||||
let mut app = Router::new()
|
let mut app = Router::new()
|
||||||
.merge(swagger_ui)
|
.merge(swagger_ui)
|
||||||
.merge(base_routes)
|
.merge(base_routes)
|
||||||
.merge(info_routes)
|
.merge(info_routes);
|
||||||
.merge(aws_sagemaker_route);
|
|
||||||
|
|
||||||
#[cfg(feature = "google")]
|
#[cfg(feature = "google")]
|
||||||
{
|
{
|
||||||
|
|
|
@ -93,7 +93,6 @@ pub struct Args {
|
||||||
// max_batch_size: Option<usize>,
|
// max_batch_size: Option<usize>,
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
validation_workers: usize,
|
validation_workers: usize,
|
||||||
messages_api_enabled: bool,
|
|
||||||
disable_grammar_support: bool,
|
disable_grammar_support: bool,
|
||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
usage_stats_level: UsageStatsLevel,
|
usage_stats_level: UsageStatsLevel,
|
||||||
|
@ -117,7 +116,6 @@ impl Args {
|
||||||
// max_batch_size: Option<usize>,
|
// max_batch_size: Option<usize>,
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
validation_workers: usize,
|
validation_workers: usize,
|
||||||
messages_api_enabled: bool,
|
|
||||||
disable_grammar_support: bool,
|
disable_grammar_support: bool,
|
||||||
max_client_batch_size: usize,
|
max_client_batch_size: usize,
|
||||||
usage_stats_level: UsageStatsLevel,
|
usage_stats_level: UsageStatsLevel,
|
||||||
|
@ -138,7 +136,6 @@ impl Args {
|
||||||
// max_batch_size,
|
// max_batch_size,
|
||||||
revision,
|
revision,
|
||||||
validation_workers,
|
validation_workers,
|
||||||
messages_api_enabled,
|
|
||||||
disable_grammar_support,
|
disable_grammar_support,
|
||||||
max_client_batch_size,
|
max_client_batch_size,
|
||||||
usage_stats_level,
|
usage_stats_level,
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
use crate::infer::Infer;
|
use crate::infer::Infer;
|
||||||
use crate::server::{generate_internal, ComputeType};
|
use crate::server::{generate_internal, ComputeType};
|
||||||
use crate::{
|
use crate::{ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest};
|
||||||
ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest, GrammarType, Message,
|
|
||||||
StreamOptions, Tool, ToolChoice,
|
|
||||||
};
|
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, StatusCode};
|
use axum::http::{HeaderMap, StatusCode};
|
||||||
use axum::response::{IntoResponse, Response};
|
use axum::response::{IntoResponse, Response};
|
||||||
|
@ -22,162 +19,12 @@ pub(crate) struct GenerateVertexInstance {
|
||||||
pub parameters: Option<GenerateParameters>,
|
pub parameters: Option<GenerateParameters>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema)]
|
|
||||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
|
||||||
pub(crate) struct VertexChat {
|
|
||||||
messages: Vec<Message>,
|
|
||||||
// Messages is ignored there.
|
|
||||||
#[serde(default)]
|
|
||||||
parameters: VertexParameters,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
|
|
||||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
|
||||||
pub(crate) struct VertexParameters {
|
|
||||||
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
|
||||||
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
|
|
||||||
pub model: Option<String>,
|
|
||||||
|
|
||||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
|
|
||||||
/// decreasing the model's likelihood to repeat the same line verbatim.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(example = "1.0")]
|
|
||||||
pub frequency_penalty: Option<f32>,
|
|
||||||
|
|
||||||
/// UNUSED
|
|
||||||
/// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
|
|
||||||
/// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
|
|
||||||
/// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
|
|
||||||
/// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
|
|
||||||
/// result in a ban or exclusive selection of the relevant token.
|
|
||||||
#[serde(default)]
|
|
||||||
pub logit_bias: Option<Vec<f32>>,
|
|
||||||
|
|
||||||
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
|
|
||||||
/// output token returned in the content of message.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(example = "false")]
|
|
||||||
pub logprobs: Option<bool>,
|
|
||||||
|
|
||||||
/// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
|
|
||||||
/// an associated log probability. logprobs must be set to true if this parameter is used.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(example = "5")]
|
|
||||||
pub top_logprobs: Option<u32>,
|
|
||||||
|
|
||||||
/// The maximum number of tokens that can be generated in the chat completion.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(example = "32")]
|
|
||||||
pub max_tokens: Option<u32>,
|
|
||||||
|
|
||||||
/// UNUSED
|
|
||||||
/// How many chat completion choices to generate for each input message. Note that you will be charged based on the
|
|
||||||
/// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, example = "2")]
|
|
||||||
pub n: Option<u32>,
|
|
||||||
|
|
||||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
|
|
||||||
/// increasing the model's likelihood to talk about new topics
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, example = 0.1)]
|
|
||||||
pub presence_penalty: Option<f32>,
|
|
||||||
|
|
||||||
/// Up to 4 sequences where the API will stop generating further tokens.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, example = "null")]
|
|
||||||
pub stop: Option<Vec<String>>,
|
|
||||||
|
|
||||||
#[serde(default = "bool::default")]
|
|
||||||
pub stream: bool,
|
|
||||||
|
|
||||||
#[schema(nullable = true, example = 42)]
|
|
||||||
pub seed: Option<u64>,
|
|
||||||
|
|
||||||
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
|
|
||||||
/// lower values like 0.2 will make it more focused and deterministic.
|
|
||||||
///
|
|
||||||
/// We generally recommend altering this or `top_p` but not both.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, example = 1.0)]
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
|
|
||||||
/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
|
|
||||||
/// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, example = 0.95)]
|
|
||||||
pub top_p: Option<f32>,
|
|
||||||
|
|
||||||
/// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
|
|
||||||
/// functions the model may generate JSON inputs for.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, example = "null")]
|
|
||||||
pub tools: Option<Vec<Tool>>,
|
|
||||||
|
|
||||||
/// A prompt to be appended before the tools
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(
|
|
||||||
nullable = true,
|
|
||||||
example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
|
|
||||||
)]
|
|
||||||
pub tool_prompt: Option<String>,
|
|
||||||
|
|
||||||
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, example = "null")]
|
|
||||||
pub tool_choice: ToolChoice,
|
|
||||||
|
|
||||||
/// Response format constraints for the generation.
|
|
||||||
///
|
|
||||||
/// NOTE: A request can use `response_format` OR `tools` but not both.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, default = "null", example = "null")]
|
|
||||||
pub response_format: Option<GrammarType>,
|
|
||||||
|
|
||||||
/// A guideline to be used in the chat_template
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, default = "null", example = "null")]
|
|
||||||
pub guideline: Option<String>,
|
|
||||||
|
|
||||||
/// Options for streaming response. Only set this when you set stream: true.
|
|
||||||
#[serde(default)]
|
|
||||||
#[schema(nullable = true, example = "null")]
|
|
||||||
pub stream_options: Option<StreamOptions>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<VertexChat> for ChatRequest {
|
|
||||||
fn from(val: VertexChat) -> Self {
|
|
||||||
Self {
|
|
||||||
messages: val.messages,
|
|
||||||
frequency_penalty: val.parameters.frequency_penalty,
|
|
||||||
guideline: val.parameters.guideline,
|
|
||||||
logit_bias: val.parameters.logit_bias,
|
|
||||||
logprobs: val.parameters.logprobs,
|
|
||||||
max_tokens: val.parameters.max_tokens,
|
|
||||||
model: val.parameters.model,
|
|
||||||
n: val.parameters.n,
|
|
||||||
presence_penalty: val.parameters.presence_penalty,
|
|
||||||
response_format: val.parameters.response_format,
|
|
||||||
seed: val.parameters.seed,
|
|
||||||
stop: val.parameters.stop,
|
|
||||||
stream_options: val.parameters.stream_options,
|
|
||||||
stream: val.parameters.stream,
|
|
||||||
temperature: val.parameters.temperature,
|
|
||||||
tool_choice: val.parameters.tool_choice,
|
|
||||||
tool_prompt: val.parameters.tool_prompt,
|
|
||||||
tools: val.parameters.tools,
|
|
||||||
top_logprobs: val.parameters.top_logprobs,
|
|
||||||
top_p: val.parameters.top_p,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema)]
|
#[derive(Clone, Deserialize, ToSchema)]
|
||||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub(crate) enum VertexInstance {
|
pub(crate) enum VertexInstance {
|
||||||
Generate(GenerateVertexInstance),
|
Generate(GenerateVertexInstance),
|
||||||
Chat(VertexChat),
|
Chat(ChatRequest),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, ToSchema)]
|
#[derive(Deserialize, ToSchema)]
|
||||||
|
@ -263,9 +110,8 @@ pub(crate) async fn vertex_compatibility(
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
VertexInstance::Chat(instance) => {
|
VertexInstance::Chat(instance) => {
|
||||||
let chat_request: ChatRequest = instance.into();
|
|
||||||
let (generate_request, _using_tools): (GenerateRequest, bool) =
|
let (generate_request, _using_tools): (GenerateRequest, bool) =
|
||||||
chat_request.try_into_generate(&infer)?;
|
instance.try_into_generate(&infer)?;
|
||||||
generate_request
|
generate_request
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -311,35 +157,15 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn vertex_deserialization() {
|
fn vertex_deserialization() {
|
||||||
let string = serde_json::json!({
|
|
||||||
|
|
||||||
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
|
|
||||||
"parameters": {
|
|
||||||
"max_tokens": 128,
|
|
||||||
"top_p": 0.95,
|
|
||||||
"temperature": 0.7
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");
|
|
||||||
|
|
||||||
let string = serde_json::json!({
|
|
||||||
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
|
|
||||||
});
|
|
||||||
|
|
||||||
let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");
|
|
||||||
|
|
||||||
let string = serde_json::json!({
|
let string = serde_json::json!({
|
||||||
|
|
||||||
"instances": [
|
"instances": [
|
||||||
{
|
{
|
||||||
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
|
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
|
||||||
"parameters": {
|
|
||||||
"max_tokens": 128,
|
"max_tokens": 128,
|
||||||
"top_p": 0.95,
|
"top_p": 0.95,
|
||||||
"temperature": 0.7
|
"temperature": 0.7
|
||||||
}
|
}
|
||||||
}
|
|
||||||
]
|
]
|
||||||
|
|
||||||
});
|
});
|
||||||
|
@ -347,18 +173,16 @@ mod tests {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
request,
|
request,
|
||||||
VertexRequest {
|
VertexRequest {
|
||||||
instances: vec![VertexInstance::Chat(VertexChat {
|
instances: vec![VertexInstance::Chat(ChatRequest {
|
||||||
messages: vec![Message {
|
messages: vec![Message {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: MessageContent::SingleText("What's Deep Learning?".to_string()),
|
content: MessageContent::SingleText("What's Deep Learning?".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
},],
|
},],
|
||||||
parameters: VertexParameters {
|
|
||||||
max_tokens: Some(128),
|
max_tokens: Some(128),
|
||||||
top_p: Some(0.95),
|
top_p: Some(0.95),
|
||||||
temperature: Some(0.7),
|
temperature: Some(0.7),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}
|
|
||||||
})]
|
})]
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
|
@ -31,7 +31,7 @@ install: install-cuda
|
||||||
echo "Installed server"
|
echo "Installed server"
|
||||||
|
|
||||||
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
|
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
|
||||||
pip install -e ".[bnb]"
|
pip install -e ".[bnb,marlin,moe]"
|
||||||
pip install nvidia-nccl-cu12==2.22.3
|
pip install nvidia-nccl-cu12==2.22.3
|
||||||
|
|
||||||
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
|
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -41,10 +41,10 @@ py-cpuinfo = "^9.0.0"
|
||||||
numpy = "^1.26"
|
numpy = "^1.26"
|
||||||
|
|
||||||
marlin-kernels = [
|
marlin-kernels = [
|
||||||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
|
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
|
||||||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
|
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
|
||||||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
|
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
|
||||||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
|
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
|
||||||
]
|
]
|
||||||
moe-kernels = [
|
moe-kernels = [
|
||||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
|
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
|
||||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
|
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
|
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
|
||||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
|
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
|
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
|
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
|
||||||
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
|
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
|
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
|
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||||
|
|
|
@ -2,7 +2,7 @@ import pytest
|
||||||
import os
|
import os
|
||||||
from text_generation_server.pb import generate_pb2
|
from text_generation_server.pb import generate_pb2
|
||||||
|
|
||||||
os.environ["USE_PREFIX_CACHING"] = "1"
|
os.environ["PREFIX_CACHING"] = "1"
|
||||||
os.environ["ATTENTION"] = "flashinfer"
|
os.environ["ATTENTION"] = "flashinfer"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,7 @@ class Dtype(str, Enum):
|
||||||
|
|
||||||
|
|
||||||
class KVCacheDtype(str, Enum):
|
class KVCacheDtype(str, Enum):
|
||||||
|
fp8_e4m3fn = "fp8_e4m3fn"
|
||||||
fp8_e5m2 = "fp8_e5m2"
|
fp8_e5m2 = "fp8_e5m2"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,9 @@ from typing import Callable, Any
|
||||||
|
|
||||||
|
|
||||||
class ExceptionInterceptor(AsyncServerInterceptor):
|
class ExceptionInterceptor(AsyncServerInterceptor):
|
||||||
|
def __init__(self, shutdown_callback):
|
||||||
|
self.shutdown_callback = shutdown_callback
|
||||||
|
|
||||||
async def intercept(
|
async def intercept(
|
||||||
self,
|
self,
|
||||||
method: Callable,
|
method: Callable,
|
||||||
|
@ -25,7 +28,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
|
||||||
|
|
||||||
# Runtime Error cannot be recovered from
|
# Runtime Error cannot be recovered from
|
||||||
if isinstance(err, RuntimeError):
|
if isinstance(err, RuntimeError):
|
||||||
exit(1)
|
self.shutdown_callback()
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
|
@ -8,39 +8,32 @@ if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||||
if SYSTEM == "cuda":
|
if SYSTEM == "cuda":
|
||||||
from .cuda import (
|
from .cuda import (
|
||||||
PREFILL_IN_KV_CACHE,
|
|
||||||
SUPPORTS_WINDOWING,
|
SUPPORTS_WINDOWING,
|
||||||
attention,
|
attention,
|
||||||
paged_attention,
|
paged_attention,
|
||||||
reshape_and_cache,
|
|
||||||
)
|
)
|
||||||
elif SYSTEM == "rocm":
|
elif SYSTEM == "rocm":
|
||||||
from .rocm import (
|
from .rocm import (
|
||||||
PREFILL_IN_KV_CACHE,
|
|
||||||
SUPPORTS_WINDOWING,
|
SUPPORTS_WINDOWING,
|
||||||
attention,
|
attention,
|
||||||
paged_attention,
|
paged_attention,
|
||||||
reshape_and_cache,
|
|
||||||
)
|
)
|
||||||
elif SYSTEM == "ipex":
|
elif SYSTEM == "ipex":
|
||||||
from .ipex import (
|
from .ipex import (
|
||||||
PREFILL_IN_KV_CACHE,
|
|
||||||
SUPPORTS_WINDOWING,
|
SUPPORTS_WINDOWING,
|
||||||
attention,
|
attention,
|
||||||
paged_attention,
|
paged_attention,
|
||||||
reshape_and_cache,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
|
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
|
||||||
|
|
||||||
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
|
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
|
||||||
from .kv_cache import KVCache
|
from .kv_cache import KVCache, get_kv_scales
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"attention",
|
"attention",
|
||||||
|
"get_kv_scales",
|
||||||
"paged_attention",
|
"paged_attention",
|
||||||
"reshape_and_cache",
|
|
||||||
"PREFILL_IN_KV_CACHE",
|
|
||||||
"SUPPORTS_WINDOWING",
|
"SUPPORTS_WINDOWING",
|
||||||
"KVCache",
|
"KVCache",
|
||||||
"Seqlen",
|
"Seqlen",
|
||||||
|
|
|
@ -1,16 +1,12 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
from text_generation_server.models.globals import ATTENTION
|
|
||||||
import torch
|
import torch
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
if ATTENTION in {"flashinfer", "flashdecoding"}:
|
@dataclass
|
||||||
|
class Seqlen:
|
||||||
@dataclass
|
|
||||||
class Seqlen:
|
|
||||||
input_lengths: torch.Tensor
|
input_lengths: torch.Tensor
|
||||||
prefix_lengths: torch.Tensor
|
cache_lengths: torch.Tensor
|
||||||
cu_seqlen_q: Optional[torch.Tensor]
|
cu_seqlen_q: Optional[torch.Tensor]
|
||||||
cu_seqlen_k: Optional[torch.Tensor]
|
cu_seqlen_k: Optional[torch.Tensor]
|
||||||
max_q: int
|
max_q: int
|
||||||
|
@ -19,13 +15,13 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
prefix_lengths,
|
cache_lengths,
|
||||||
cu_seqlen_q=None,
|
cu_seqlen_q=None,
|
||||||
max_q=None,
|
max_q=None,
|
||||||
max_k=None,
|
max_k=None,
|
||||||
):
|
):
|
||||||
self.input_lengths = input_lengths
|
self.input_lengths = input_lengths
|
||||||
self.prefix_lengths = prefix_lengths
|
self.cache_lengths = cache_lengths
|
||||||
device = self.input_lengths.device
|
device = self.input_lengths.device
|
||||||
shape = self.input_lengths.shape
|
shape = self.input_lengths.shape
|
||||||
if cu_seqlen_q is None:
|
if cu_seqlen_q is None:
|
||||||
|
@ -43,7 +39,7 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
|
||||||
# cuda graphs don't like this and this is necessary to clamp within mistral
|
# cuda graphs don't like this and this is necessary to clamp within mistral
|
||||||
# Although FA2 might not want the clamping
|
# Although FA2 might not want the clamping
|
||||||
# cu_seqlen_k[0] = 0
|
# cu_seqlen_k[0] = 0
|
||||||
total = self.input_lengths + self.prefix_lengths
|
total = self.input_lengths + self.cache_lengths
|
||||||
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
|
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
|
||||||
|
|
||||||
self.cu_seqlen_q = cu_seqlen_q
|
self.cu_seqlen_q = cu_seqlen_q
|
||||||
|
@ -54,19 +50,3 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
|
||||||
def clamp(self, max):
|
def clamp(self, max):
|
||||||
# Flash decoding doesn't need to clamp
|
# Flash decoding doesn't need to clamp
|
||||||
return self
|
return self
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Seqlen:
|
|
||||||
input_lengths: torch.Tensor
|
|
||||||
prefix_lengths: torch.Tensor
|
|
||||||
cu_seqlen_q: torch.Tensor
|
|
||||||
max_q: int
|
|
||||||
max_k: int
|
|
||||||
|
|
||||||
def clamp(self, max):
|
|
||||||
if SYSTEM == "rocm":
|
|
||||||
return self
|
|
||||||
self.input_lengths = torch.clamp(self.input_lengths, max=max)
|
|
||||||
return self
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
|
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.models.globals import (
|
from text_generation_server.models.globals import (
|
||||||
ATTENTION,
|
ATTENTION,
|
||||||
|
@ -7,44 +8,22 @@ from text_generation_server.models.globals import (
|
||||||
from text_generation_server.layers.attention import Seqlen
|
from text_generation_server.layers.attention import Seqlen
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
major, minor = torch.cuda.get_device_capability()
|
major, minor = torch.cuda.get_device_capability()
|
||||||
is_sm75 = major == 7 and minor == 5
|
is_sm75 = major == 7 and minor == 5
|
||||||
_PARTITION_SIZE = 512
|
_PARTITION_SIZE = 512
|
||||||
|
|
||||||
try:
|
|
||||||
from vllm._C import cache_ops
|
|
||||||
except Exception as e:
|
|
||||||
raise ImportError(
|
|
||||||
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def reshape_and_cache(
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
|
||||||
):
|
|
||||||
if ATTENTION in {"flashdecoding", "flashinfer"}:
|
|
||||||
shape = key_cache.shape
|
|
||||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
|
||||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
|
||||||
else:
|
|
||||||
cache_ops.reshape_and_cache(
|
|
||||||
key, value, key_cache, value_cache, slots, "auto", 1.0
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def paged_attention(
|
def paged_attention(
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key_cache: torch.Tensor,
|
kv_cache: KVCache,
|
||||||
value_cache: torch.Tensor,
|
|
||||||
kv_head_mapping: torch.Tensor,
|
kv_head_mapping: torch.Tensor,
|
||||||
softmax_scale: float,
|
softmax_scale: float,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
|
*,
|
||||||
|
kv_scales: KVScales,
|
||||||
softcap: Optional[float] = None,
|
softcap: Optional[float] = None,
|
||||||
):
|
):
|
||||||
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
||||||
|
@ -70,6 +49,8 @@ def paged_attention(
|
||||||
num_seqs, num_heads, head_size = query.shape
|
num_seqs, num_heads, head_size = query.shape
|
||||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||||
|
|
||||||
|
can_scale = kv_cache.can_scale(kv_scales)
|
||||||
|
|
||||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||||
|
@ -79,10 +60,13 @@ def paged_attention(
|
||||||
from text_generation_server.layers.attention.flashinfer import decode_state
|
from text_generation_server.layers.attention.flashinfer import decode_state
|
||||||
|
|
||||||
return decode_state.get().forward(
|
return decode_state.get().forward(
|
||||||
|
# TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
|
||||||
query.contiguous(),
|
query.contiguous(),
|
||||||
paged_kv_cache=(key_cache, value_cache),
|
paged_kv_cache=(kv_cache.key, kv_cache.value),
|
||||||
logits_soft_cap=softcap,
|
logits_soft_cap=softcap,
|
||||||
sm_scale=softmax_scale,
|
sm_scale=softmax_scale,
|
||||||
|
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
|
||||||
|
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
|
||||||
)
|
)
|
||||||
elif ATTENTION == "flashdecoding":
|
elif ATTENTION == "flashdecoding":
|
||||||
max_q = 1
|
max_q = 1
|
||||||
|
@ -98,8 +82,8 @@ def paged_attention(
|
||||||
softcap = 0.0
|
softcap = 0.0
|
||||||
out = flash_attn_2_cuda.varlen_fwd(
|
out = flash_attn_2_cuda.varlen_fwd(
|
||||||
query,
|
query,
|
||||||
key_cache,
|
kv_cache.key,
|
||||||
value_cache,
|
kv_cache.value,
|
||||||
None,
|
None,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
seqlen.cu_seqlen_k,
|
seqlen.cu_seqlen_k,
|
||||||
|
@ -123,7 +107,7 @@ def paged_attention(
|
||||||
else:
|
else:
|
||||||
if softcap is not None:
|
if softcap is not None:
|
||||||
raise RuntimeError("Paged attention doesn't support softcapping")
|
raise RuntimeError("Paged attention doesn't support softcapping")
|
||||||
input_lengths = seqlen.input_lengths
|
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
||||||
from vllm._C import ops
|
from vllm._C import ops
|
||||||
|
|
||||||
out = torch.empty_like(query)
|
out = torch.empty_like(query)
|
||||||
|
@ -135,8 +119,8 @@ def paged_attention(
|
||||||
ops.paged_attention_v1(
|
ops.paged_attention_v1(
|
||||||
out,
|
out,
|
||||||
query,
|
query,
|
||||||
key_cache,
|
kv_cache.key,
|
||||||
value_cache,
|
kv_cache.value,
|
||||||
kv_head_mapping,
|
kv_head_mapping,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
@ -168,8 +152,8 @@ def paged_attention(
|
||||||
max_logits,
|
max_logits,
|
||||||
tmp_output,
|
tmp_output,
|
||||||
query,
|
query,
|
||||||
key_cache,
|
kv_cache.key,
|
||||||
value_cache,
|
kv_cache.value,
|
||||||
kv_head_mapping,
|
kv_head_mapping,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
@ -216,60 +200,69 @@ except ImportError:
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
if ATTENTION == "flashdecoding" and not V2:
|
||||||
|
raise ValueError("Flash decoding requires Flash Attention V2")
|
||||||
|
|
||||||
SUPPORTS_WINDOWING = V2
|
SUPPORTS_WINDOWING = V2
|
||||||
|
|
||||||
if ATTENTION == "flashinfer":
|
|
||||||
|
|
||||||
def attention(
|
def attention(
|
||||||
q: torch.Tensor,
|
*,
|
||||||
key_cache: torch.Tensor,
|
query: torch.Tensor,
|
||||||
value_cache: torch.Tensor,
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
kv_scales: KVScales,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
softmax_scale,
|
softmax_scale: float,
|
||||||
window_size_left=-1,
|
window_size_left: int = -1,
|
||||||
causal=True,
|
causal: bool = True,
|
||||||
softcap=0.0,
|
softcap: Optional[float] = None,
|
||||||
):
|
):
|
||||||
|
can_scale = kv_cache.can_scale(kv_scales)
|
||||||
|
|
||||||
|
if ATTENTION == "flashinfer":
|
||||||
from text_generation_server.layers.attention.flashinfer import (
|
from text_generation_server.layers.attention.flashinfer import (
|
||||||
prefill_with_paged_kv_state,
|
prefill_with_paged_kv_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if softcap is None:
|
||||||
|
softcap = 0.0
|
||||||
|
|
||||||
return prefill_with_paged_kv_state.get().forward(
|
return prefill_with_paged_kv_state.get().forward(
|
||||||
q.contiguous(),
|
# TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
|
||||||
|
query.contiguous(),
|
||||||
causal=causal,
|
causal=causal,
|
||||||
paged_kv_cache=(key_cache, value_cache),
|
paged_kv_cache=(kv_cache.key, kv_cache.value),
|
||||||
logits_soft_cap=softcap,
|
logits_soft_cap=softcap,
|
||||||
sm_scale=softmax_scale,
|
sm_scale=softmax_scale,
|
||||||
window_left=window_size_left,
|
window_left=window_size_left,
|
||||||
|
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
|
||||||
|
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif V2:
|
# If we are using flashdecoding or paged, we always use flash-attn for
|
||||||
|
# the prefill. We have to branch on whether we use flash-attn v1 or v2.
|
||||||
def attention(
|
elif V2:
|
||||||
q,
|
out = torch.empty_like(query)
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
seqlen: Seqlen,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
softmax_scale,
|
|
||||||
window_size_left=-1,
|
|
||||||
causal=True,
|
|
||||||
softcap=0.0,
|
|
||||||
):
|
|
||||||
out = torch.empty_like(q)
|
|
||||||
if window_size_left <= 0 and window_size_left != -1:
|
if window_size_left <= 0 and window_size_left != -1:
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||||
|
|
||||||
|
if softcap is None:
|
||||||
|
softcap = 0.0
|
||||||
|
|
||||||
return flash_attn_2_cuda.varlen_fwd(
|
return flash_attn_2_cuda.varlen_fwd(
|
||||||
q,
|
query,
|
||||||
key_cache,
|
# flashdecoding: pass the KV caches, paged: pass the KV.
|
||||||
value_cache,
|
kv_cache.key if ATTENTION == "flashdecoding" else key,
|
||||||
|
kv_cache.value if ATTENTION == "flashdecoding" else value,
|
||||||
out,
|
out,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
seqlen.cu_seqlen_k,
|
seqlen.cu_seqlen_k,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
block_tables,
|
block_tables if ATTENTION == "flashdecoding" else None,
|
||||||
None,
|
None,
|
||||||
seqlen.max_q,
|
seqlen.max_q,
|
||||||
seqlen.max_k,
|
seqlen.max_k,
|
||||||
|
@ -284,57 +277,45 @@ elif V2:
|
||||||
None,
|
None,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def attention(
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
seqlen: Seqlen,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
softmax_scale: float,
|
|
||||||
window_size_left: int = -1,
|
|
||||||
causal: bool = True,
|
|
||||||
softcap=None,
|
|
||||||
):
|
|
||||||
if window_size_left != -1:
|
if window_size_left != -1:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"window_size_left is only available with flash attn v2"
|
"window_size_left is only available with flash attn v2"
|
||||||
)
|
)
|
||||||
if softcap is not None:
|
if softcap is not None:
|
||||||
raise NotImplementedError("softcap is only available with flash attn v2")
|
raise NotImplementedError("softcap is not available in flash attn v1")
|
||||||
|
|
||||||
# Flash attention v1 requires q, k and v to have the same number of heads
|
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||||
if k.shape[1] != q.shape[1]:
|
if key.shape[1] != query.shape[1]:
|
||||||
# MQA expand
|
# MQA expand
|
||||||
if k.shape[1] == 1:
|
if key.shape[1] == 1:
|
||||||
k = k.expand(-1, q.shape[1], -1)
|
key = key.expand(-1, query.shape[1], -1)
|
||||||
# Grouped attention reshape
|
# Grouped attention reshape
|
||||||
else:
|
else:
|
||||||
original_shape = k.shape
|
original_shape = key.shape
|
||||||
k = (
|
key = (
|
||||||
k.unsqueeze(2)
|
key.unsqueeze(2)
|
||||||
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
|
.expand(-1, -1, query.shape[1] // key.shape[1], -1)
|
||||||
.reshape(original_shape[0], -1, original_shape[2])
|
.reshape(original_shape[0], -1, original_shape[2])
|
||||||
)
|
)
|
||||||
if v.shape[1] != q.shape[1]:
|
if value.shape[1] != query.shape[1]:
|
||||||
# MQA expand
|
# MQA expand
|
||||||
if v.shape[1] == 1:
|
if value.shape[1] == 1:
|
||||||
v = v.expand(-1, q.shape[1], -1)
|
value = value.expand(-1, query.shape[1], -1)
|
||||||
# Grouped attention reshape
|
# Grouped attention reshape
|
||||||
else:
|
else:
|
||||||
original_shape = v.shape
|
original_shape = value.shape
|
||||||
v = (
|
value = (
|
||||||
v.unsqueeze(2)
|
value.unsqueeze(2)
|
||||||
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
|
.expand(-1, -1, query.shape[1] // value.shape[1], -1)
|
||||||
.reshape(original_shape[0], -1, original_shape[2])
|
.reshape(original_shape[0], -1, original_shape[2])
|
||||||
)
|
)
|
||||||
|
|
||||||
out = torch.empty_like(q)
|
out = torch.empty_like(query)
|
||||||
flash_attn_cuda.fwd(
|
flash_attn_cuda.fwd(
|
||||||
q,
|
query,
|
||||||
k,
|
key,
|
||||||
v,
|
value,
|
||||||
out,
|
out,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
|
@ -351,15 +332,8 @@ else:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
# Prefill in the cache with every kind of attention, unless we
|
|
||||||
# have a configuration that requires flash-attention v1, which
|
|
||||||
# does not support block tables.
|
|
||||||
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PREFILL_IN_KV_CACHE",
|
|
||||||
"SUPPORTS_WINDOWING",
|
"SUPPORTS_WINDOWING",
|
||||||
"attention",
|
"attention",
|
||||||
"paged_attention",
|
"paged_attention",
|
||||||
"reshape_and_cache",
|
|
||||||
]
|
]
|
||||||
|
|
|
@ -699,7 +699,6 @@ def check_args(
|
||||||
|
|
||||||
|
|
||||||
class _attention(torch.autograd.Function):
|
class _attention(torch.autograd.Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(
|
def forward(
|
||||||
ctx,
|
ctx,
|
||||||
|
|
|
@ -204,6 +204,7 @@ def use_decode_state(
|
||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
page_size: int,
|
page_size: int,
|
||||||
|
kv_cache_dtype: torch.dtype,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
window_left: int,
|
window_left: int,
|
||||||
):
|
):
|
||||||
|
@ -240,7 +241,7 @@ def use_decode_state(
|
||||||
num_kv_heads=num_kv_heads,
|
num_kv_heads=num_kv_heads,
|
||||||
head_dim=head_size,
|
head_dim=head_size,
|
||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
data_type=dtype,
|
data_type=kv_cache_dtype,
|
||||||
q_data_type=dtype,
|
q_data_type=dtype,
|
||||||
window_left=window_left,
|
window_left=window_left,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,31 +1,37 @@
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
import torch
|
import torch
|
||||||
|
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
|
||||||
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
|
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
|
||||||
from text_generation_server.layers.attention import Seqlen
|
from text_generation_server.layers.attention import Seqlen
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
SUPPORTS_WINDOWING = False
|
SUPPORTS_WINDOWING = False
|
||||||
PREFILL_IN_KV_CACHE = False
|
|
||||||
|
|
||||||
|
|
||||||
def attention(
|
def attention(
|
||||||
q: torch.Tensor,
|
*,
|
||||||
key_cache: torch.Tensor,
|
query: torch.Tensor,
|
||||||
value_cache: torch.Tensor,
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
kv_scales: KVScales,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
softmax_scale,
|
softmax_scale: float,
|
||||||
window_size_left=-1,
|
window_size_left: int = -1,
|
||||||
causal=True,
|
causal: bool = True,
|
||||||
softcap: Optional[float] = None,
|
softcap: Optional[float] = None,
|
||||||
):
|
):
|
||||||
out = torch.empty_like(q)
|
if softcap is not None:
|
||||||
|
raise NotImplementedError("softcap is not available in IPEX")
|
||||||
|
|
||||||
|
out = torch.empty_like(query)
|
||||||
|
|
||||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||||
ipex.llm.functional.varlen_attention(
|
ipex.llm.functional.varlen_attention(
|
||||||
q.contiguous() if q.device.type == "xpu" else q,
|
query.contiguous() if query.device.type == "xpu" else query,
|
||||||
key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
|
key.contiguous() if key.device.type == "xpu" else key,
|
||||||
value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
|
value.contiguous() if value.device.type == "xpu" else value,
|
||||||
out,
|
out,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
|
@ -42,39 +48,32 @@ def attention(
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def reshape_and_cache(
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
|
||||||
):
|
|
||||||
ipex.llm.modules.PagedAttention.reshape_and_cache(
|
|
||||||
key, value, key_cache, value_cache, slots
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def paged_attention(
|
def paged_attention(
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key_cache: torch.Tensor,
|
kv_cache: KVCache,
|
||||||
value_cache: torch.Tensor,
|
|
||||||
kv_head_mapping: torch.Tensor,
|
kv_head_mapping: torch.Tensor,
|
||||||
softmax_scale: float,
|
softmax_scale: float,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
|
*,
|
||||||
|
kv_scales: KVScales,
|
||||||
softcap: Optional[float] = None,
|
softcap: Optional[float] = None,
|
||||||
):
|
):
|
||||||
|
if softcap is not None:
|
||||||
|
raise NotImplementedError("softcap is not available in IPEX")
|
||||||
|
|
||||||
out = torch.empty_like(query)
|
out = torch.empty_like(query)
|
||||||
|
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
||||||
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||||
out,
|
out,
|
||||||
query,
|
query,
|
||||||
key_cache,
|
kv_cache.key,
|
||||||
value_cache,
|
kv_cache.value,
|
||||||
kv_head_mapping,
|
kv_head_mapping,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
seqlen.input_lengths,
|
input_lengths,
|
||||||
BLOCK_SIZE,
|
BLOCK_SIZE,
|
||||||
max_s,
|
max_s,
|
||||||
None,
|
None,
|
||||||
|
@ -83,9 +82,7 @@ def paged_attention(
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PREFILL_IN_KV_CACHE",
|
|
||||||
"SUPPORTS_WINDOWING",
|
"SUPPORTS_WINDOWING",
|
||||||
"attention",
|
"attention",
|
||||||
"paged_attention",
|
"paged_attention",
|
||||||
"reshape_and_cache",
|
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,9 +1,38 @@
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from text_generation_server.layers.fp8 import fp8_quantize
|
||||||
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
|
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.layers.attention import reshape_and_cache
|
from text_generation_server.utils.log import log_once
|
||||||
|
from text_generation_server.utils.weights import Weights
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class KVScales:
|
||||||
|
"""
|
||||||
|
Key-value scales for FP8 KV cache.
|
||||||
|
|
||||||
|
This data class stores key and value scales both as a GPU tensor and
|
||||||
|
as a GPU float. This inconvenience is necessary because some functions
|
||||||
|
(e.g. scaling kernels) take scales as a GPU tensor, whereas others
|
||||||
|
(e.g. flashinfer) take scales as a CPU scalar.
|
||||||
|
"""
|
||||||
|
|
||||||
|
key_scale: torch.Tensor
|
||||||
|
value_scale: torch.Tensor
|
||||||
|
key_scale_cpu: float = field(init=False)
|
||||||
|
value_scale_cpu: float = field(init=False)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
|
||||||
|
raise ValueError("Key and value scales must be scalar tensors.")
|
||||||
|
|
||||||
|
self.key_scale_cpu = self.key_scale.item()
|
||||||
|
self.value_scale_cpu = self.value_scale.item()
|
||||||
|
|
||||||
|
|
||||||
class KVCache:
|
class KVCache:
|
||||||
|
@ -24,11 +53,11 @@ class KVCache:
|
||||||
):
|
):
|
||||||
"""Construct the key-value cache for a layer."""
|
"""Construct the key-value cache for a layer."""
|
||||||
|
|
||||||
if dtype == torch.float8_e5m2 and (
|
if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
|
||||||
ATTENTION != "flashinfer" or SYSTEM != "cuda"
|
ATTENTION != "flashinfer" or SYSTEM != "cuda"
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
|
"FP8 KV cache is currently only supported for flashinfer on CUDA"
|
||||||
)
|
)
|
||||||
|
|
||||||
element_size = torch.tensor([], dtype=dtype).element_size()
|
element_size = torch.tensor([], dtype=dtype).element_size()
|
||||||
|
@ -77,6 +106,33 @@ class KVCache:
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def can_scale(self, kv_scales: KVScales) -> bool:
|
||||||
|
"""Check if the cache can be scaled by the given scales."""
|
||||||
|
if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
|
||||||
|
return False
|
||||||
|
elif (
|
||||||
|
self.dtype == torch.float8_e4m3fn
|
||||||
|
and ATTENTION == "flashinfer"
|
||||||
|
and SYSTEM == "cuda"
|
||||||
|
):
|
||||||
|
log_once(
|
||||||
|
logger.info,
|
||||||
|
"Using FP8 KV cache scales",
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
# We have scales, but not the correct FP8 cache type, so warn once.
|
||||||
|
log_once(
|
||||||
|
logger.info,
|
||||||
|
"Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
"""Get the data type of the cache."""
|
||||||
|
return self.kv_cache[0].dtype
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def key(self):
|
def key(self):
|
||||||
"""Get the key cache."""
|
"""Get the key cache."""
|
||||||
|
@ -95,18 +151,34 @@ class KVCache:
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
|
kv_scales: KVScales,
|
||||||
):
|
):
|
||||||
"""Store the key and value at the given slots."""
|
"""Store the key and value at the given slots."""
|
||||||
|
|
||||||
key_cache = self.kv_cache[0]
|
key_cache = self.kv_cache[0]
|
||||||
value_cache = self.kv_cache[1]
|
value_cache = self.kv_cache[1]
|
||||||
|
|
||||||
|
if self.can_scale(kv_scales):
|
||||||
|
if kv_scales.key_scale_cpu != 1.0:
|
||||||
|
key = fp8_quantize(
|
||||||
|
key.float(),
|
||||||
|
scale=kv_scales.key_scale,
|
||||||
|
qdtype=self.dtype,
|
||||||
|
scalar=True,
|
||||||
|
)[0]
|
||||||
|
if kv_scales.value_scale_cpu != 1.0:
|
||||||
|
value = fp8_quantize(
|
||||||
|
value.float(),
|
||||||
|
scale=kv_scales.value_scale,
|
||||||
|
qdtype=self.dtype,
|
||||||
|
scalar=True,
|
||||||
|
)[0]
|
||||||
|
|
||||||
if ATTENTION in {"flashdecoding", "flashinfer"}:
|
if ATTENTION in {"flashdecoding", "flashinfer"}:
|
||||||
# TODO: add scale
|
|
||||||
key = key.to(key_cache.dtype)
|
key = key.to(key_cache.dtype)
|
||||||
value = value.to(value_cache.dtype)
|
value = value.to(value_cache.dtype)
|
||||||
if key_cache.dtype == torch.float8_e5m2:
|
if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
|
||||||
# Torch index_put does not support float8_e5m2 yet, so
|
# Torch index_put does not support float8_{e5m2,e4m3fn} yet, so
|
||||||
# put as raw data instead.
|
# put as raw data instead.
|
||||||
key_cache = key_cache.view(torch.uint8)
|
key_cache = key_cache.view(torch.uint8)
|
||||||
value_cache = value_cache.view(torch.uint8)
|
value_cache = value_cache.view(torch.uint8)
|
||||||
|
@ -116,4 +188,59 @@ class KVCache:
|
||||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||||
else:
|
else:
|
||||||
reshape_and_cache(key, value, key_cache, value_cache, slots)
|
paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
|
||||||
|
|
||||||
|
|
||||||
|
def paged_reshape_and_cache(
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
):
|
||||||
|
if SYSTEM == "cuda":
|
||||||
|
try:
|
||||||
|
from vllm._C import cache_ops
|
||||||
|
except Exception as e:
|
||||||
|
raise ImportError(
|
||||||
|
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
||||||
|
)
|
||||||
|
cache_ops.reshape_and_cache(
|
||||||
|
key, value, key_cache, value_cache, slots, "auto", 1.0
|
||||||
|
)
|
||||||
|
elif SYSTEM == "rocm":
|
||||||
|
try:
|
||||||
|
import vllm._custom_ops as ops
|
||||||
|
except Exception as e:
|
||||||
|
raise ImportError(
|
||||||
|
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
||||||
|
)
|
||||||
|
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
|
||||||
|
elif SYSTEM == "ipex":
|
||||||
|
import intel_extension_for_pytorch as ipex
|
||||||
|
|
||||||
|
ipex.llm.modules.PagedAttention.reshape_and_cache(
|
||||||
|
key, value, key_cache, value_cache, slots
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
|
||||||
|
"""Load KV cache scales."""
|
||||||
|
|
||||||
|
key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
|
||||||
|
value_scale = key_scale
|
||||||
|
if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
|
||||||
|
f"{prefix}.v_scale"
|
||||||
|
):
|
||||||
|
key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
|
||||||
|
value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
|
||||||
|
elif weights.has_tensor(f"{prefix}.kv_scale"):
|
||||||
|
# Fall back to older more coarse-grained scale when available.
|
||||||
|
key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
|
||||||
|
value_scale = key_scale
|
||||||
|
|
||||||
|
return KVScales(key_scale=key_scale, value_scale=value_scale)
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import torch
|
import torch
|
||||||
|
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.models.globals import ATTENTION
|
|
||||||
from text_generation_server.layers.attention import Seqlen
|
from text_generation_server.layers.attention import Seqlen
|
||||||
from text_generation_server.utils.log import log_master
|
from text_generation_server.utils.log import log_master
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
@ -16,8 +16,6 @@ _PARTITION_SIZE_CUSTOM = 256
|
||||||
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
|
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
|
||||||
ENGINE = "triton" if use_triton else "ck"
|
ENGINE = "triton" if use_triton else "ck"
|
||||||
|
|
||||||
PREFILL_IN_KV_CACHE = False
|
|
||||||
|
|
||||||
use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
|
use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
|
||||||
try:
|
try:
|
||||||
if use_rocm_custom_paged_attn:
|
if use_rocm_custom_paged_attn:
|
||||||
|
@ -29,38 +27,17 @@ except ImportError as e:
|
||||||
)
|
)
|
||||||
use_rocm_custom_paged_attn = False
|
use_rocm_custom_paged_attn = False
|
||||||
|
|
||||||
try:
|
|
||||||
import vllm._custom_ops as ops
|
|
||||||
except Exception as e:
|
|
||||||
raise ImportError(
|
|
||||||
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def reshape_and_cache(
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
|
||||||
):
|
|
||||||
if ATTENTION == "flashdecoding":
|
|
||||||
shape = key_cache.shape
|
|
||||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
|
||||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
|
||||||
else:
|
|
||||||
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
|
|
||||||
|
|
||||||
|
|
||||||
def paged_attention(
|
def paged_attention(
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key_cache: torch.Tensor,
|
kv_cache: KVCache,
|
||||||
value_cache: torch.Tensor,
|
|
||||||
kv_head_mapping: torch.Tensor,
|
kv_head_mapping: torch.Tensor,
|
||||||
softmax_scale: float,
|
softmax_scale: float,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
|
*,
|
||||||
|
kv_scales: KVScales,
|
||||||
softcap: Optional[float] = None,
|
softcap: Optional[float] = None,
|
||||||
):
|
):
|
||||||
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
||||||
|
@ -84,10 +61,10 @@ def paged_attention(
|
||||||
raise RuntimeError("Paged attention doesn't support softcapping")
|
raise RuntimeError("Paged attention doesn't support softcapping")
|
||||||
|
|
||||||
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
||||||
block_size = value_cache.shape[3]
|
block_size = kv_cache.value.shape[3]
|
||||||
num_seqs, num_heads, head_size = query.shape
|
num_seqs, num_heads, head_size = query.shape
|
||||||
|
|
||||||
num_kv_heads = key_cache.shape[1]
|
num_kv_heads = kv_cache.key.shape[1]
|
||||||
gqa_ratio = num_heads // num_kv_heads
|
gqa_ratio = num_heads // num_kv_heads
|
||||||
use_custom = (
|
use_custom = (
|
||||||
use_rocm_custom_paged_attn
|
use_rocm_custom_paged_attn
|
||||||
|
@ -104,7 +81,7 @@ def paged_attention(
|
||||||
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
|
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
|
||||||
|
|
||||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||||
input_lengths = seqlen.input_lengths
|
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
||||||
|
|
||||||
out = torch.empty_like(query)
|
out = torch.empty_like(query)
|
||||||
|
|
||||||
|
@ -124,8 +101,8 @@ def paged_attention(
|
||||||
ops.paged_attention_v1(
|
ops.paged_attention_v1(
|
||||||
out,
|
out,
|
||||||
query,
|
query,
|
||||||
key_cache,
|
kv_cache.key,
|
||||||
value_cache,
|
kv_cache.value,
|
||||||
kv_head_mapping,
|
kv_head_mapping,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
@ -158,8 +135,8 @@ def paged_attention(
|
||||||
max_logits,
|
max_logits,
|
||||||
tmp_output,
|
tmp_output,
|
||||||
query,
|
query,
|
||||||
key_cache,
|
kv_cache.key,
|
||||||
value_cache,
|
kv_cache.value,
|
||||||
kv_head_mapping,
|
kv_head_mapping,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
@ -177,8 +154,8 @@ def paged_attention(
|
||||||
max_logits,
|
max_logits,
|
||||||
tmp_output,
|
tmp_output,
|
||||||
query,
|
query,
|
||||||
key_cache,
|
kv_cache.key,
|
||||||
value_cache,
|
kv_cache.value,
|
||||||
num_kv_heads,
|
num_kv_heads,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
@ -227,29 +204,36 @@ if ENGINE != "triton":
|
||||||
|
|
||||||
|
|
||||||
SUPPORTS_WINDOWING = False
|
SUPPORTS_WINDOWING = False
|
||||||
if ENGINE == "ck":
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q,
|
def attention(
|
||||||
key_cache: torch.Tensor,
|
*,
|
||||||
value_cache: torch.Tensor,
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
kv_scales: KVScales,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
softmax_scale: float,
|
softmax_scale: float,
|
||||||
window_size_left: int = -1,
|
window_size_left: int = -1,
|
||||||
causal: bool = True,
|
causal: bool = True,
|
||||||
softcap: float = 0.0,
|
softcap: Optional[float] = None,
|
||||||
):
|
):
|
||||||
|
if ENGINE == "ck":
|
||||||
if window_size_left <= 0 and window_size_left != -1:
|
if window_size_left <= 0 and window_size_left != -1:
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||||
|
|
||||||
out = torch.empty_like(q)
|
out = torch.empty_like(query)
|
||||||
|
|
||||||
|
if softcap is None:
|
||||||
|
softcap = 0.0
|
||||||
|
|
||||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||||
return flash_attn_2_cuda.varlen_fwd(
|
return flash_attn_2_cuda.varlen_fwd(
|
||||||
q,
|
query,
|
||||||
key_cache,
|
key,
|
||||||
value_cache,
|
value,
|
||||||
out,
|
out,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
|
@ -270,30 +254,19 @@ if ENGINE == "ck":
|
||||||
None,
|
None,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
elif ENGINE == "triton":
|
elif ENGINE == "triton":
|
||||||
from .flash_attn_triton import triton_attention
|
from .flash_attn_triton import triton_attention
|
||||||
|
|
||||||
def attention(
|
|
||||||
q,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
seqlen: Seqlen,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
softmax_scale: float,
|
|
||||||
window_size_left: int = -1,
|
|
||||||
causal: bool = True,
|
|
||||||
softcap: Optional[float] = None,
|
|
||||||
):
|
|
||||||
if softcap is not None:
|
if softcap is not None:
|
||||||
raise NotImplementedError("softcap is only available with CK flash attn")
|
raise NotImplementedError("softcap is only available with CK flash attn")
|
||||||
|
|
||||||
out = torch.empty_like(q)
|
out = torch.empty_like(query)
|
||||||
|
|
||||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||||
output, _ = triton_attention(
|
output, _ = triton_attention(
|
||||||
q,
|
query,
|
||||||
key_cache,
|
key,
|
||||||
value_cache,
|
value,
|
||||||
out,
|
out,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
|
@ -304,13 +277,12 @@ elif ENGINE == "triton":
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Unknown attention engine {ENGINE}")
|
raise RuntimeError(f"Unknown attention engine {ENGINE}")
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PREFILL_IN_KV_CACHE",
|
|
||||||
"SUPPORTS_WINDOWING",
|
"SUPPORTS_WINDOWING",
|
||||||
"attention",
|
"attention",
|
||||||
"paged_attention",
|
"paged_attention",
|
||||||
"reshape_and_cache",
|
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
|
if SYSTEM == "ipex":
|
||||||
|
from .ipex import WQLinear
|
||||||
|
elif SYSTEM == "cuda":
|
||||||
|
from .cuda import WQLinear
|
||||||
|
|
||||||
|
__all__ = ["WQLinear"]
|
|
@ -0,0 +1,48 @@
|
||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import intel_extension_for_pytorch as ipex
|
||||||
|
|
||||||
|
|
||||||
|
class WQLinear(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if w_bit not in [4]:
|
||||||
|
raise NotImplementedError("Only 4-bit are supported for now.")
|
||||||
|
|
||||||
|
self.in_features = qweight.shape[0]
|
||||||
|
self.out_features = qweight.shape[1] * 32 // w_bit
|
||||||
|
|
||||||
|
self.w_bit = w_bit
|
||||||
|
self.group_size = group_size if group_size != -1 else self.in_features
|
||||||
|
# quick sanity check (make sure aligment)
|
||||||
|
assert self.in_features % self.group_size == 0
|
||||||
|
assert self.out_features % (32 // self.w_bit) == 0
|
||||||
|
|
||||||
|
self.qweight = qweight
|
||||||
|
self.qzeros = qzeros
|
||||||
|
self.scales = scales
|
||||||
|
self.bias = bias
|
||||||
|
self.woq_linear = (
|
||||||
|
ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
|
||||||
|
self.qweight,
|
||||||
|
self.scales,
|
||||||
|
self.qzeros,
|
||||||
|
self.in_features,
|
||||||
|
self.out_features,
|
||||||
|
bias=self.bias,
|
||||||
|
group_size=self.group_size,
|
||||||
|
quant_method=ipex.llm.quantization.QuantMethod.AWQ_GEMM,
|
||||||
|
dtype=ipex.llm.quantization.QuantDtype.INT4,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def forward(self, x):
|
||||||
|
out_shape = x.shape[:-1] + (self.out_features,)
|
||||||
|
out = self.woq_linear(x.reshape(-1, x.shape[-1]))
|
||||||
|
out = out + self.bias if self.bias is not None else out
|
||||||
|
return out.reshape(out_shape)
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Union, List
|
from typing import Optional, Tuple, Union, List
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
@ -26,6 +26,12 @@ def is_fbgemm_gpu_available():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import marlin_kernels
|
||||||
|
except ImportError:
|
||||||
|
marlin_kernels = None
|
||||||
|
|
||||||
|
|
||||||
if is_fbgemm_gpu_available():
|
if is_fbgemm_gpu_available():
|
||||||
if SYSTEM == "cuda":
|
if SYSTEM == "cuda":
|
||||||
major, _ = torch.cuda.get_device_capability()
|
major, _ = torch.cuda.get_device_capability()
|
||||||
|
@ -51,27 +57,82 @@ def get_fp8_linear() -> torch.nn.Module:
|
||||||
return Fp8Linear
|
return Fp8Linear
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_e4m3fn_to_e4m3fnuz(
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
input_scale: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
assert weight.dtype == torch.float8_e4m3fn
|
||||||
|
# The bits pattern 10000000(-128) represents zero in e4m3fn
|
||||||
|
# but NaN in e4m3fnuz. So here we set it to 0.
|
||||||
|
# https://onnx.ai/onnx/technical/float8.html
|
||||||
|
weight_as_int8 = weight.view(torch.int8)
|
||||||
|
ROCM_FP8_NAN_AS_INT = -128
|
||||||
|
weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
|
||||||
|
weight = weight_as_int8.view(torch.float8_e4m3fnuz)
|
||||||
|
|
||||||
|
# For the same bits representation, e4m3fnuz value is half of
|
||||||
|
# the e4m3fn value, so we should double the scaling factor to
|
||||||
|
# get the same dequantized value.
|
||||||
|
# https://onnx.ai/onnx/technical/float8.html
|
||||||
|
weight_scale = weight_scale * 2.0
|
||||||
|
if input_scale is not None:
|
||||||
|
input_scale = input_scale * 2.0
|
||||||
|
return weight, weight_scale, input_scale
|
||||||
|
|
||||||
|
|
||||||
def fp8_quantize(
|
def fp8_quantize(
|
||||||
weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
|
weight: torch.Tensor,
|
||||||
|
scale: Optional[torch.Tensor] = None,
|
||||||
|
scale_upper_bound: Optional[torch.Tensor] = None,
|
||||||
|
qdtype: torch.dtype = torch.float8_e4m3fn,
|
||||||
|
scalar: bool = False,
|
||||||
):
|
):
|
||||||
if FBGEMM_DYN_AVAILABLE and not scalar:
|
"""
|
||||||
|
This function returns a reciprocal of the scale, so that a tensor can be unscaled
|
||||||
|
by multiplying it with the returned scale. If a scale is given through the `scale`
|
||||||
|
argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
|
||||||
|
be used without modification).
|
||||||
|
"""
|
||||||
|
if FBGEMM_DYN_AVAILABLE and not scalar and not scale:
|
||||||
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
||||||
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
|
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
|
||||||
)
|
)
|
||||||
return qweight, scale
|
return qweight, scale
|
||||||
|
|
||||||
|
if marlin_kernels is not None:
|
||||||
|
shape = weight.shape
|
||||||
|
qweight, scale = marlin_kernels.scaled_fp8_quant(
|
||||||
|
weight.reshape(-1, shape[-1]),
|
||||||
|
dtype=qdtype,
|
||||||
|
scale=scale,
|
||||||
|
scale_ub=scale_upper_bound,
|
||||||
|
)
|
||||||
|
|
||||||
|
return qweight.reshape(shape), scale
|
||||||
|
|
||||||
# weight, scale = quant_weights(weight, torch.int8, False)
|
# weight, scale = quant_weights(weight, torch.int8, False)
|
||||||
finfo = torch.finfo(qdtype)
|
finfo = torch.finfo(qdtype)
|
||||||
|
|
||||||
|
if scale is None:
|
||||||
# Calculate the scale as dtype max divided by absmax
|
# Calculate the scale as dtype max divided by absmax
|
||||||
scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
|
scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
|
||||||
# scale and clamp the tensor to bring it to
|
# scale and clamp the tensor to bring it to
|
||||||
# the representative range of float8 data type
|
# the representative range of float8 data type
|
||||||
# (as default cast is unsaturated)
|
# (as default cast is unsaturated)
|
||||||
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
|
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
|
||||||
|
scale = scale.float().reciprocal()
|
||||||
|
else:
|
||||||
|
# Use reciprocal to avoid more expensive division.
|
||||||
|
qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
|
||||||
|
|
||||||
# Return both float8 data and the inverse scale (as float),
|
# Return both float8 data and the inverse scale (as float),
|
||||||
# as both required as inputs to torch._scaled_mm
|
# as both required as inputs to torch._scaled_mm
|
||||||
qweight = qweight.to(qdtype)
|
qweight = qweight.to(qdtype)
|
||||||
scale = scale.float().reciprocal()
|
|
||||||
|
if SYSTEM == "rocm":
|
||||||
|
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
|
||||||
|
|
||||||
return qweight, scale
|
return qweight, scale
|
||||||
|
|
||||||
|
|
||||||
|
@ -92,9 +153,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||||
.reshape(-1)
|
.reshape(-1)
|
||||||
.expand(w.shape[0])
|
.expand(w.shape[0])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
input_scale = None
|
||||||
|
if weights.has_tensor(f"{prefix}.input_scale"):
|
||||||
|
input_scale = weights.get_tensor(
|
||||||
|
f"{prefix}.input_scale", to_dtype=False
|
||||||
|
).reshape(-1)
|
||||||
|
|
||||||
return Fp8Weight(
|
return Fp8Weight(
|
||||||
weight=w,
|
weight=w,
|
||||||
weight_scale=scale,
|
weight_scale=scale,
|
||||||
|
input_scale=input_scale,
|
||||||
activation_scale_ub=self.activation_scale_ub,
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
dtype=weights.dtype,
|
dtype=weights.dtype,
|
||||||
)
|
)
|
||||||
|
@ -125,9 +194,24 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||||
)
|
)
|
||||||
scale = scale.reshape(-1).expand(w.shape[0])
|
scale = scale.reshape(-1).expand(w.shape[0])
|
||||||
|
|
||||||
|
input_scale = None
|
||||||
|
if weights.has_tensor(f"{prefix}.input_scale"):
|
||||||
|
input_scale = weights.get_tensor(
|
||||||
|
f"{prefix}.input_scale", to_dtype=False
|
||||||
|
)
|
||||||
|
if input_scale.numel() > 1:
|
||||||
|
input_scale = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.input_scale",
|
||||||
|
dim=0,
|
||||||
|
block_sizes=block_sizes,
|
||||||
|
to_dtype=False,
|
||||||
|
)
|
||||||
|
input_scale = input_scale.reshape(-1).max()
|
||||||
|
|
||||||
return Fp8Weight(
|
return Fp8Weight(
|
||||||
weight=w,
|
weight=w,
|
||||||
weight_scale=scale,
|
weight_scale=scale,
|
||||||
|
input_scale=input_scale,
|
||||||
activation_scale_ub=self.activation_scale_ub,
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
dtype=weights.dtype,
|
dtype=weights.dtype,
|
||||||
)
|
)
|
||||||
|
@ -154,9 +238,22 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||||
]
|
]
|
||||||
scale = torch.cat(scale, dim=0).reshape(-1)
|
scale = torch.cat(scale, dim=0).reshape(-1)
|
||||||
|
|
||||||
|
input_scale = [
|
||||||
|
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
|
||||||
|
for p, shape in zip(prefixes, shapes)
|
||||||
|
if weights.has_tensor(f"{p}.input_scale")
|
||||||
|
]
|
||||||
|
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
|
||||||
|
input_scale = (
|
||||||
|
torch.cat(input_scale, dim=0).reshape(-1).max()
|
||||||
|
if len(input_scale) != 0
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
return Fp8Weight(
|
return Fp8Weight(
|
||||||
weight=w,
|
weight=w,
|
||||||
weight_scale=scale,
|
weight_scale=scale,
|
||||||
|
input_scale=input_scale,
|
||||||
activation_scale_ub=self.activation_scale_ub,
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
dtype=weights.dtype,
|
dtype=weights.dtype,
|
||||||
)
|
)
|
||||||
|
@ -174,9 +271,16 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||||
.reshape(-1)
|
.reshape(-1)
|
||||||
.expand(w.shape[0])
|
.expand(w.shape[0])
|
||||||
)
|
)
|
||||||
|
input_scale = None
|
||||||
|
if weights.has_tensor(f"{prefix}.input_scale"):
|
||||||
|
input_scale = weights.get_tensor(
|
||||||
|
f"{prefix}.input_scale", to_dtype=False
|
||||||
|
).reshape(-1)
|
||||||
|
|
||||||
return Fp8Weight(
|
return Fp8Weight(
|
||||||
weight=w,
|
weight=w,
|
||||||
weight_scale=scale,
|
weight_scale=scale,
|
||||||
|
input_scale=input_scale,
|
||||||
activation_scale_ub=self.activation_scale_ub,
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
dtype=weights.dtype,
|
dtype=weights.dtype,
|
||||||
)
|
)
|
||||||
|
@ -191,6 +295,7 @@ class Fp8Weight(Weight):
|
||||||
weight: torch.Tensor
|
weight: torch.Tensor
|
||||||
dtype: torch.dtype
|
dtype: torch.dtype
|
||||||
weight_scale: Optional[torch.Tensor] = None
|
weight_scale: Optional[torch.Tensor] = None
|
||||||
|
input_scale: Optional[torch.Tensor] = None
|
||||||
activation_scale_ub: Optional[float] = None
|
activation_scale_ub: Optional[float] = None
|
||||||
|
|
||||||
def get_linear(self, bias: torch.Tensor):
|
def get_linear(self, bias: torch.Tensor):
|
||||||
|
@ -200,26 +305,41 @@ class Fp8Weight(Weight):
|
||||||
# memory. Can be non-contiguous when we e.g. expand from scalars.
|
# memory. Can be non-contiguous when we e.g. expand from scalars.
|
||||||
self.weight_scale = self.weight_scale.contiguous()
|
self.weight_scale = self.weight_scale.contiguous()
|
||||||
return get_fp8_linear().from_fp8(
|
return get_fp8_linear().from_fp8(
|
||||||
self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
|
weight=self.weight,
|
||||||
|
scale=self.weight_scale,
|
||||||
|
dtype=self.dtype,
|
||||||
|
bias=bias,
|
||||||
|
input_scale=self.input_scale,
|
||||||
|
scale_upper_bound=self.activation_scale_ub,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Fp8Linear(torch.nn.Module):
|
class Fp8Linear(torch.nn.Module):
|
||||||
|
_device_identity_cache = {}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
qweight,
|
qweight: torch.Tensor,
|
||||||
scale,
|
scale: torch.Tensor,
|
||||||
scale_upper_bound,
|
dtype: torch.dtype,
|
||||||
bias,
|
bias: Optional[torch.Tensor] = None,
|
||||||
dtype,
|
input_scale: Optional[torch.Tensor] = None,
|
||||||
|
scale_upper_bound: Optional[float] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if FBGEMM_MM_AVAILABLE:
|
if FBGEMM_MM_AVAILABLE:
|
||||||
log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
|
log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
|
||||||
|
if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
|
||||||
|
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
|
weight=qweight, weight_scale=scale
|
||||||
|
)
|
||||||
|
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.qweight = qweight
|
self.qweight = qweight
|
||||||
self.scale = scale
|
self.scale = scale.float()
|
||||||
|
self.input_scale = input_scale.float() if input_scale is not None else None
|
||||||
|
|
||||||
|
if FBGEMM_MM_AVAILABLE:
|
||||||
self.scale_upper_bound = (
|
self.scale_upper_bound = (
|
||||||
torch.tensor(
|
torch.tensor(
|
||||||
[scale_upper_bound], dtype=torch.float32, device=qweight.device
|
[scale_upper_bound], dtype=torch.float32, device=qweight.device
|
||||||
|
@ -227,6 +347,8 @@ class Fp8Linear(torch.nn.Module):
|
||||||
if scale_upper_bound is not None
|
if scale_upper_bound is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
self.scale_upper_bound = scale_upper_bound
|
||||||
|
|
||||||
self.bias = bias if bias is not None else None
|
self.bias = bias if bias is not None else None
|
||||||
|
|
||||||
|
@ -234,22 +356,46 @@ class Fp8Linear(torch.nn.Module):
|
||||||
def from_unquant(cls, weight, bias, dtype):
|
def from_unquant(cls, weight, bias, dtype):
|
||||||
qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
|
qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
|
||||||
return cls(
|
return cls(
|
||||||
qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
|
qweight=qweight,
|
||||||
|
scale=scale,
|
||||||
|
dtype=dtype,
|
||||||
|
bias=bias,
|
||||||
|
input_scale=None,
|
||||||
|
scale_upper_bound=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
|
def from_fp8(
|
||||||
|
cls,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
scale: torch.Tensor,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> "Fp8Linear":
|
||||||
|
input_scale = kwargs.get("input_scale", None)
|
||||||
|
scale_upper_bound = kwargs.get("scale_upper_bound", None)
|
||||||
|
|
||||||
if FBGEMM_DYN_AVAILABLE:
|
if FBGEMM_DYN_AVAILABLE:
|
||||||
# fbgemm needs float32 scales.
|
# fbgemm needs float32 scales.
|
||||||
scale = scale.float()
|
scale = scale.float()
|
||||||
return cls(
|
return cls(
|
||||||
qweight=weight,
|
qweight=weight,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
scale_upper_bound=input_scale,
|
input_scale=input_scale,
|
||||||
|
scale_upper_bound=scale_upper_bound,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_shared_device_identity(cls, device):
|
||||||
|
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||||
|
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||||
|
if device not in cls._device_identity_cache:
|
||||||
|
cls._device_identity_cache[device] = torch.ones(1, device=device)
|
||||||
|
return cls._device_identity_cache[device]
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
if FBGEMM_MM_AVAILABLE:
|
if FBGEMM_MM_AVAILABLE:
|
||||||
qinput, scale = fp8_quantize(
|
qinput, scale = fp8_quantize(
|
||||||
|
@ -266,8 +412,18 @@ class Fp8Linear(torch.nn.Module):
|
||||||
)
|
)
|
||||||
return y.to(self.dtype)
|
return y.to(self.dtype)
|
||||||
|
|
||||||
qinput, scale = fp8_quantize(input, scalar=True)
|
qinput, scale = fp8_quantize(
|
||||||
output, _ = torch._scaled_mm(
|
input,
|
||||||
|
self.input_scale,
|
||||||
|
scale_upper_bound=self.scale_upper_bound,
|
||||||
|
scalar=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
per_tensor_weights = self.scale.numel() == 1
|
||||||
|
per_tensor_activations = scale.numel() == 1
|
||||||
|
|
||||||
|
if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations):
|
||||||
|
output = torch._scaled_mm(
|
||||||
qinput,
|
qinput,
|
||||||
self.qweight.t(),
|
self.qweight.t(),
|
||||||
out_dtype=self.dtype,
|
out_dtype=self.dtype,
|
||||||
|
@ -275,6 +431,30 @@ class Fp8Linear(torch.nn.Module):
|
||||||
scale_b=self.scale,
|
scale_b=self.scale,
|
||||||
bias=self.bias,
|
bias=self.bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(output, tuple) and len(output) == 2:
|
||||||
|
output = output[0]
|
||||||
|
else:
|
||||||
|
device_identity = None
|
||||||
|
if SYSTEM == "rocm":
|
||||||
|
device_identity = self.get_shared_device_identity(self.qweight.device)
|
||||||
|
|
||||||
|
output = torch._scaled_mm(
|
||||||
|
qinput,
|
||||||
|
self.qweight.t(),
|
||||||
|
scale_a=device_identity,
|
||||||
|
scale_b=device_identity,
|
||||||
|
out_dtype=torch.float32,
|
||||||
|
)
|
||||||
|
if isinstance(output, tuple) and len(output) == 2:
|
||||||
|
output = output[0]
|
||||||
|
|
||||||
|
output = output * scale * self.scale.t()
|
||||||
|
if self.bias is not None:
|
||||||
|
output = output + self.bias
|
||||||
|
|
||||||
|
output = output.to(dtype=self.dtype)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,11 @@ from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.utils.log import log_once
|
from text_generation_server.utils.log import log_once
|
||||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||||
|
|
||||||
|
if SYSTEM == "ipex":
|
||||||
|
from .ipex import QuantLinear
|
||||||
|
elif SYSTEM in {"cuda", "rocm"}:
|
||||||
|
from .triton import QuantLinear
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GPTQWeight(Weight):
|
class GPTQWeight(Weight):
|
||||||
|
@ -36,7 +41,7 @@ class GPTQWeight(Weight):
|
||||||
"to use Exllama/GPTQ kernels for AWQ inference."
|
"to use Exllama/GPTQ kernels for AWQ inference."
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
|
from text_generation_server.layers.awq.quantize import WQLinear
|
||||||
|
|
||||||
return WQLinear(
|
return WQLinear(
|
||||||
w_bit=self.bits,
|
w_bit=self.bits,
|
||||||
|
@ -60,8 +65,6 @@ class GPTQWeight(Weight):
|
||||||
|
|
||||||
return ExllamaQuantLinear(self, bias)
|
return ExllamaQuantLinear(self, bias)
|
||||||
else:
|
else:
|
||||||
from text_generation_server.layers.gptq.quant_linear import QuantLinear
|
|
||||||
|
|
||||||
return QuantLinear(
|
return QuantLinear(
|
||||||
self.qweight,
|
self.qweight,
|
||||||
self.qzeros,
|
self.qzeros,
|
||||||
|
@ -298,6 +301,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||||
self._get_gptq_params(weights)
|
self._get_gptq_params(weights)
|
||||||
|
|
||||||
use_exllama = True
|
use_exllama = True
|
||||||
|
desc_act = self.desc_act
|
||||||
if self.bits != 4:
|
if self.bits != 4:
|
||||||
use_exllama = False
|
use_exllama = False
|
||||||
|
|
||||||
|
@ -321,7 +325,8 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||||
if g_idx is not None:
|
if g_idx is not None:
|
||||||
if (
|
if (
|
||||||
not torch.equal(
|
not torch.equal(
|
||||||
g_idx.cpu(),
|
# Remove g_idx[0] to adapt the check with TP>1.
|
||||||
|
(g_idx - g_idx[0]).cpu(),
|
||||||
torch.tensor(
|
torch.tensor(
|
||||||
[i // self.groupsize for i in range(g_idx.shape[0])],
|
[i // self.groupsize for i in range(g_idx.shape[0])],
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
|
@ -332,6 +337,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||||
# Exllama implementation does not support row tensor parallelism with act-order, as
|
# Exllama implementation does not support row tensor parallelism with act-order, as
|
||||||
# it would require to reorder input activations that are split unto several GPUs
|
# it would require to reorder input activations that are split unto several GPUs
|
||||||
use_exllama = False
|
use_exllama = False
|
||||||
|
desc_act = True
|
||||||
|
|
||||||
from text_generation_server.layers.gptq import (
|
from text_generation_server.layers.gptq import (
|
||||||
CAN_EXLLAMA,
|
CAN_EXLLAMA,
|
||||||
|
@ -350,16 +356,16 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||||
else:
|
else:
|
||||||
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
|
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
|
||||||
|
|
||||||
if use_exllama and self.groupsize != -1:
|
if not desc_act and self.groupsize != -1:
|
||||||
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
|
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||||
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
||||||
|
if g_idx is not None:
|
||||||
|
# qzeros, scales sharded, and g_idx must be adjusted accordingly
|
||||||
|
g_idx = g_idx - g_idx[0]
|
||||||
else:
|
else:
|
||||||
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||||
scales = weights.get_tensor(f"{prefix}.scales")
|
scales = weights.get_tensor(f"{prefix}.scales")
|
||||||
|
|
||||||
if use_exllama and g_idx is not None:
|
|
||||||
g_idx = g_idx - g_idx[0]
|
|
||||||
|
|
||||||
if self.quantize == "gptq" and self.quant_method == "awq":
|
if self.quantize == "gptq" and self.quant_method == "awq":
|
||||||
log_once(
|
log_once(
|
||||||
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||||
|
@ -392,7 +398,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_gptq_params(self, weights: Weights):
|
def _get_gptq_params(self, weights: Weights):
|
||||||
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
|
if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
|
||||||
self.bits = weights.get_tensor("gptq_bits").item()
|
self.bits = weights.get_tensor("gptq_bits").item()
|
||||||
self.groupsize = weights.get_tensor("gptq_groupsize").item()
|
self.groupsize = weights.get_tensor("gptq_groupsize").item()
|
||||||
self.desc_act = False
|
self.desc_act = False
|
||||||
|
@ -400,7 +406,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||||
# before the `gptq_sym` setting tensor was added.
|
# before the `gptq_sym` setting tensor was added.
|
||||||
self.sym = (
|
self.sym = (
|
||||||
weights.get_tensor("gptq_sym").item()
|
weights.get_tensor("gptq_sym").item()
|
||||||
if weights._has_tensor("gptq_sym")
|
if weights.has_tensor("gptq_sym")
|
||||||
else False
|
else False
|
||||||
)
|
)
|
||||||
self.quant_method = "gptq"
|
self.quant_method = "gptq"
|
||||||
|
|
|
@ -0,0 +1,126 @@
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
import intel_extension_for_pytorch as ipex
|
||||||
|
|
||||||
|
|
||||||
|
class QuantLinear(nn.Module):
|
||||||
|
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
|
||||||
|
super().__init__()
|
||||||
|
self.register_buffer("qweight", qweight)
|
||||||
|
self.register_buffer("qzeros", qzeros)
|
||||||
|
self.register_buffer("scales", scales)
|
||||||
|
self.register_buffer("g_idx", g_idx)
|
||||||
|
if bias is not None:
|
||||||
|
self.register_buffer("bias", bias)
|
||||||
|
else:
|
||||||
|
self.bias = None
|
||||||
|
if bits not in [4]:
|
||||||
|
raise NotImplementedError("Only 4 bits are supported.")
|
||||||
|
self.bits = bits
|
||||||
|
self.maxq = 2**self.bits - 1
|
||||||
|
self.groupsize = groupsize
|
||||||
|
|
||||||
|
self.outfeatures = qweight.shape[1]
|
||||||
|
self.infeatures = qweight.shape[0] * 32 // bits
|
||||||
|
self.woq_linear = (
|
||||||
|
ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
|
||||||
|
self.qweight,
|
||||||
|
self.scales,
|
||||||
|
self.qzeros,
|
||||||
|
self.infeatures,
|
||||||
|
self.outfeatures,
|
||||||
|
bias=self.bias,
|
||||||
|
group_size=self.groupsize,
|
||||||
|
g_idx=g_idx,
|
||||||
|
quant_method=ipex.llm.quantization.QuantMethod.GPTQ_GEMM,
|
||||||
|
dtype=ipex.llm.quantization.QuantDtype.INT4,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
|
||||||
|
if bits not in [4]:
|
||||||
|
raise NotImplementedError("Only 4 bits are supported.")
|
||||||
|
|
||||||
|
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
|
||||||
|
qzeros = torch.zeros(
|
||||||
|
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
scales = torch.zeros(
|
||||||
|
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
|
||||||
|
)
|
||||||
|
g_idx = torch.tensor(
|
||||||
|
[i // groupsize for i in range(infeatures)], dtype=torch.int32
|
||||||
|
)
|
||||||
|
if bias:
|
||||||
|
bias = torch.zeros((outfeatures), dtype=torch.float16)
|
||||||
|
else:
|
||||||
|
bias = None
|
||||||
|
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
|
||||||
|
|
||||||
|
def pack(self, linear, scales, zeros, g_idx=None):
|
||||||
|
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
|
||||||
|
|
||||||
|
scales = scales.t().contiguous()
|
||||||
|
zeros = zeros.t().contiguous()
|
||||||
|
scale_zeros = zeros * scales
|
||||||
|
self.scales = scales.clone().half()
|
||||||
|
if linear.bias is not None:
|
||||||
|
self.bias = linear.bias.clone().half()
|
||||||
|
|
||||||
|
intweight = []
|
||||||
|
for idx in range(self.infeatures):
|
||||||
|
intweight.append(
|
||||||
|
torch.round(
|
||||||
|
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
|
||||||
|
/ self.scales[self.g_idx[idx]]
|
||||||
|
).to(torch.int)[:, None]
|
||||||
|
)
|
||||||
|
intweight = torch.cat(intweight, dim=1)
|
||||||
|
intweight = intweight.t().contiguous()
|
||||||
|
intweight = intweight.numpy().astype(np.uint32)
|
||||||
|
qweight = np.zeros(
|
||||||
|
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
|
||||||
|
)
|
||||||
|
i = 0
|
||||||
|
row = 0
|
||||||
|
while row < qweight.shape[0]:
|
||||||
|
if self.bits in [4]:
|
||||||
|
for j in range(i, i + (32 // self.bits)):
|
||||||
|
qweight[row] |= intweight[j] << (self.bits * (j - i))
|
||||||
|
i += 32 // self.bits
|
||||||
|
row += 1
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Only 4 bits are supported.")
|
||||||
|
|
||||||
|
qweight = qweight.astype(np.int32)
|
||||||
|
self.qweight = torch.from_numpy(qweight)
|
||||||
|
|
||||||
|
zeros -= 1
|
||||||
|
zeros = zeros.numpy().astype(np.uint32)
|
||||||
|
qzeros = np.zeros(
|
||||||
|
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
|
||||||
|
)
|
||||||
|
i = 0
|
||||||
|
col = 0
|
||||||
|
while col < qzeros.shape[1]:
|
||||||
|
if self.bits in [4]:
|
||||||
|
for j in range(i, i + (32 // self.bits)):
|
||||||
|
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
|
||||||
|
i += 32 // self.bits
|
||||||
|
col += 1
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Only 4 bits are supported.")
|
||||||
|
|
||||||
|
qzeros = qzeros.astype(np.int32)
|
||||||
|
self.qzeros = torch.from_numpy(qzeros)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out_shape = x.shape[:-1] + (self.outfeatures,)
|
||||||
|
out = self.woq_linear(x.reshape(-1, x.shape[-1]))
|
||||||
|
out = out + self.bias if self.bias is not None else out
|
||||||
|
return out.reshape(out_shape)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue