fix docker

This commit is contained in:
Mohit Sharma 2024-09-12 15:45:06 +00:00
commit 4ba9210f91
33 changed files with 4396 additions and 1251 deletions

41
.github/workflows/nix_tests.yaml vendored Normal file
View File

@ -0,0 +1,41 @@
name: "Nix Tests"
on:
pull_request:
paths:
- ".github/workflows/nix_tests.yaml"
- "server/**"
- "proto/**"
- "router/**"
- "launcher/**"
- "Cargo.lock"
- "rust-toolchain.toml"
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
tests:
runs-on:
group: aws-highmemory-32-plus-priv
steps:
- uses: actions/checkout@v4
- uses: cachix/install-nix-action@v27
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
with:
name: text-generation-inference
# If you chose signing key for write access
authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
env:
USER: github_runner
- name: Build
run: nix develop .#test --command echo "Ok"
- name: Pre-commit tests.
run: nix develop .#test --command pre-commit run --all-files
- name: Python tests.
run: nix develop .#test --command python -m pytest server/tests/
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
- name: Rust tests.
run: nix develop .#test --command cargo test

View File

@ -17,19 +17,15 @@ concurrency:
jobs:
run_tests:
runs-on: ubuntu-latest
env:
SCCACHE_GHA_ENABLED: "on"
RUSTC_WRAPPER: /usr/local/bin/sccache
SCCACHE: 0.3.3
runs-on:
group: aws-highmemory-32-plus-priv
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v1
uses: actions/setup-python@v4
id: python
with:
python-version: 3.9
python-version: 3.11
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
@ -44,30 +40,9 @@ jobs:
run: |
sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android
sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
- name: Install sccache
run: |
curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache
chmod +x /usr/local/bin/sccache
- name: configure sccache
uses: actions/github-script@v6
with:
script: |
core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
core.exportVariable('SCCACHE_GHA_CACHE_TO', 'sccache-${{runner.os}}-${{github.ref_name}}');
core.exportVariable('SCCACHE_GHA_CACHE_FROM', 'sccache-${{runner.os}}-main,sccache-${{runner.os}}-');
- name: cargo registry cache
uses: actions/cache@v3
with:
key: cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-${{ github.sha }}
restore-keys: |
cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-
cargo-${{ runner.os }}-
path: |
~/.cargo/registry
~/.cargo/git
- name: Install
run: |
sudo apt install python3.11-dev -y
make install-cpu
- name: Run server tests
run: |
@ -82,6 +57,3 @@ jobs:
- name: Run Rust tests
run: |
cargo test
- name: sccache stats
run: |
/usr/local/bin/sccache --show-stats

120
Cargo.lock generated
View File

@ -2118,6 +2118,15 @@ version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "memoffset"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
[[package]]
name = "metrics"
version = "0.23.0"
@ -3112,6 +3121,69 @@ dependencies = [
"prost 0.12.6",
]
[[package]]
name = "pyo3"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433"
dependencies = [
"cfg-if",
"indoc",
"libc",
"memoffset",
"once_cell",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-build-config"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.76",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.76",
]
[[package]]
name = "qoi"
version = "0.4.1"
@ -4068,7 +4140,7 @@ dependencies = [
"pkg-config",
"text-generation-router",
"thiserror",
"tokenizers",
"tokenizers 0.19.1",
"tokio",
"tokio-stream",
"tracing",
@ -4091,7 +4163,7 @@ dependencies = [
"tabled",
"text-generation-client",
"thiserror",
"tokenizers",
"tokenizers 0.20.0",
"tokio",
"tracing",
"tracing-subscriber",
@ -4161,6 +4233,7 @@ dependencies = [
"once_cell",
"opentelemetry 0.20.0",
"opentelemetry-otlp",
"pyo3",
"rand",
"regex",
"reqwest",
@ -4168,7 +4241,7 @@ dependencies = [
"serde_json",
"sysinfo",
"thiserror",
"tokenizers",
"tokenizers 0.20.0",
"tokio",
"tokio-stream",
"tower-http",
@ -4219,7 +4292,7 @@ dependencies = [
"slotmap",
"text-generation-router",
"thiserror",
"tokenizers",
"tokenizers 0.20.0",
"tokio",
"tokio-stream",
"tonic 0.10.2",
@ -4374,6 +4447,39 @@ dependencies = [
"unicode_categories",
]
[[package]]
name = "tokenizers"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8a24d7f7d6be5b9d1377418b893ab1808af0074f5d1bb2c64784452ddd2aa70"
dependencies = [
"aho-corasick",
"derive_builder",
"esaxx-rs",
"getrandom",
"hf-hub",
"indicatif",
"itertools 0.12.1",
"lazy_static",
"log",
"macro_rules_attribute",
"monostate",
"onig",
"paste",
"rand",
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.8.4",
"serde",
"serde_json",
"spm_precompiled",
"thiserror",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
]
[[package]]
name = "tokio"
version = "1.39.3"
@ -4839,6 +4945,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unindent"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
[[package]]
name = "untrusted"
version = "0.7.1"

View File

@ -25,7 +25,7 @@ homepage = "https://github.com/huggingface/text-generation-inference"
[workspace.dependencies]
base64 = "0.22.0"
tokenizers = { version = "0.19.1", features = ["http"] }
tokenizers = { version = "0.20.0", features = ["http"] }
hf-hub = { version = "0.3.1", features = ["tokio"] }
metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }

View File

@ -13,10 +13,13 @@ COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.11-dev
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@ -37,6 +40,7 @@ COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt
RUN cargo build --profile release-opt
# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
@ -45,7 +49,7 @@ FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
ARG PYTORCH_VERSION=2.4.0
ARG PYTHON_VERSION=3.10
ARG PYTHON_VERSION=3.11
# Keep in sync with `server/pyproject.toml
ARG CUDA_VERSION=12.4
ARG MAMBA_VERSION=24.3.0-0
@ -216,33 +220,33 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
COPY --from=pytorch-install /opt/conda /opt/conda
# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /opt/conda/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-v2-builder /opt/conda/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from awq kernels builder
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from eetq kernels builder
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from lorax punica kernels builder
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from fbgemm builder
COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.11/cmake-install /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
COPY --from=flashinfer-builder /opt/conda/lib/python3.10/site-packages/flashinfer/ /opt/conda/lib/python3.10/site-packages/flashinfer/
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
COPY --from=flashinfer-builder /opt/conda/lib/python3.11/site-packages/flashinfer/ /opt/conda/lib/python3.11/site-packages/flashinfer/
# Install flash-attention dependencies
RUN pip install einops --no-cache-dir
@ -257,7 +261,9 @@ RUN cd server && \
pip install ".[bnb, accelerate, marlin, quantize, peft, outlines]" --no-cache-dir && \
pip install nvidia-nccl-cu12==2.22.3
ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
# Required to find libpython within the rust binaries
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
# This is needed because exl2 tries to load flash-attn
# And fails with our builds.
ENV EXLLAMA_NO_FLASH_ATTN=1

View File

@ -17,6 +17,8 @@ RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.11-dev
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@ -65,15 +67,15 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
hipsolver-dev \
rccl-dev \
cmake \
python3-dev \
python3-venv && \
python3.11-dev \
python3.11-venv && \
rm -rf /var/lib/apt/lists/*
# Keep in sync with `server/pyproject.toml
ARG MAMBA_VERSION=23.1.0-1
ARG PYTORCH_VERSION='2.3.0'
ARG ROCM_VERSION='6.0.2'
ARG PYTHON_VERSION='3.10.10'
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH=/opt/conda/bin:$PATH
@ -85,22 +87,6 @@ RUN curl -fsSL -v -o cmake-3.30.2-linux-x86_64.sh https://github.com/Kitware/CMa
&& ./cmake-3.30.2-linux-x86_64.sh --skip-license --prefix=/usr/local \
&& rm cmake-3.30.2-linux-x86_64.sh
RUN pip install joblib msgpack
# Install HIPBLASLt
ARG HIPBLASLT_BRANCH="6f65c6e"
RUN git clone https://github.com/ROCm/hipBLASLt \
&& cd hipBLASLt \
&& git checkout ${HIPBLASLT_BRANCH} \
&& SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \
&& cd build/release \
&& make package
RUN dpkg -i hipBLASLt/build/release/*.deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status;
# && cd .. \
# && rm -rf hipBLASLt
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
# translating Docker's TARGETPLATFORM into mamba arches
@ -114,10 +100,32 @@ RUN chmod +x ~/mambaforge.sh && \
mamba init && \
rm ~/mambaforge.sh
# Install flash-attention, torch dependencies
RUN pip install numpy einops ninja --no-cache-dir
# RUN conda install intel::mkl-static intel::mkl-include
# Install pytorch
# On arm64 we exit with an error code
RUN case ${TARGETPLATFORM} in \
"linux/arm64") exit 1 ;; \
*) /opt/conda/bin/conda update -y conda && \
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
esac && \
/opt/conda/bin/conda clean -ya
# Install flash-attention, torch dependencies
RUN pip install numpy einops ninja joblib msgpack --no-cache-dir
# Install HIPBLASLt
ARG HIPBLASLT_BRANCH="6f65c6e"
RUN git clone https://github.com/ROCm/hipBLASLt \
&& cd hipBLASLt \
&& git checkout ${HIPBLASLT_BRANCH} \
&& SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \
&& cd build/release \
&& make package
RUN dpkg -i hipBLASLt/build/release/*.deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status \
&& rm -rf hipBLASLt
RUN conda install mkl-static mkl-include
RUN pip uninstall -y triton && \
git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
cd triton/python && \
@ -201,19 +209,19 @@ ENV HF_HOME=/data \
PORT=80
# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Install server
COPY proto proto
@ -230,6 +238,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/l
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
# AWS Sagemaker compatible image
FROM base AS sagemaker

View File

@ -18,6 +18,8 @@ RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
python3.11-dev
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
@ -114,7 +116,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80
ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.10.10'
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH
@ -153,6 +155,7 @@ ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
# Install server
COPY proto proto

View File

@ -376,10 +376,9 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
// Send generation responses back to the infer task
// If the receive an error from the Flume channel, it means that the client dropped the
// request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| {
let stopped = send_responses(generation, entry).inspect_err(|_err| {
tracing::error!("Entry response channel error.");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
err
}).unwrap_or(true);
if stopped {
entries.remove(&id).expect("ID not found in entries. This is a bug.");

View File

@ -357,6 +357,7 @@ impl State {
let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
(block_allocation, &self.block_allocator)
{
tracing::debug!("Allocating {tokens} with {input_ids:?}");
match block_allocator.allocate(tokens, input_ids).await {
None => {
// Entry is over budget

View File

@ -123,8 +123,6 @@ impl Allocator for RadixAllocator {
prefill_tokens: prefill_tokens.clone(),
};
tracing::debug!("Blocks {blocks:?}");
self.allocation_id += 1;
self.allocations.insert(self.allocation_id, allocation);

View File

@ -492,6 +492,24 @@
"type": "github"
}
},
"flake-utils_7": {
"inputs": {
"systems": "systems_7"
},
"locked": {
"lastModified": 1710146030,
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"gitignore": {
"inputs": {
"nixpkgs": [
@ -700,16 +718,16 @@
},
"nixpkgs_6": {
"locked": {
"lastModified": 1723912943,
"narHash": "sha256-39F9GzyhxYcY3wTeKuEFWRJWcrGBosO4nf4xzMTWZX8=",
"owner": "danieldk",
"lastModified": 1724915739,
"narHash": "sha256-7PgRge4mn5akFvhPwefuaLQGbF5BnmxlwZJEf7CgbrE=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "b82cdca86dbb30013b76c4b55d48806476820a5c",
"rev": "85be051bb60943d3328d91aaf2598798f87e19af",
"type": "github"
},
"original": {
"owner": "danieldk",
"ref": "cuda-12.4",
"owner": "nixos",
"ref": "nixos-unstable-small",
"repo": "nixpkgs",
"type": "github"
}
@ -835,11 +853,11 @@
]
},
"locked": {
"lastModified": 1724638882,
"narHash": "sha256-ap2jIQi/FuUHR6HCht6ASWhoz8EiB99XmI8Esot38VE=",
"lastModified": 1726021481,
"narHash": "sha256-4J4E+Fh+77XIYnq2RVtg+ENWXpu6t74P0jKN/f2RQmI=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "19b70f147b9c67a759e35824b241f1ed92e46694",
"rev": "1c2c120246c51a644c20ba2a36a33d3bd4860d70",
"type": "github"
},
"original": {
@ -938,17 +956,33 @@
"type": "github"
}
},
"systems_7": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"tgi-nix": {
"inputs": {
"flake-compat": "flake-compat_4",
"flake-utils": "flake-utils_7",
"nixpkgs": "nixpkgs_6"
},
"locked": {
"lastModified": 1725011596,
"narHash": "sha256-zfq8lOXFgJnKxxsqSelHuKUvhxgH3cEmLoAgsOO62Cg=",
"lastModified": 1725950569,
"narHash": "sha256-nJHA1SvIQbXySpL2ueNbzQOhnkQASa5tOLz/kdW0PWA=",
"owner": "danieldk",
"repo": "tgi-nix",
"rev": "717c2b07e38538abf05237cca65b2d1363c2c9af",
"rev": "d40f3c22e9bcc5e16c94d4605cf6a7d74dd07f46",
"type": "github"
},
"original": {

View File

@ -46,12 +46,30 @@
launcher = cargoNix.workspaceMembers.text-generation-launcher.build.override {
inherit crateOverrides;
};
router = cargoNix.workspaceMembers.text-generation-router-v3.build.override {
inherit crateOverrides;
};
router =
let
routerUnwrapped = cargoNix.workspaceMembers.text-generation-router-v3.build.override {
inherit crateOverrides;
};
packagePath =
with pkgs.python3.pkgs;
makePythonPath [
protobuf
sentencepiece
torch
transformers
];
in
pkgs.writeShellApplication {
name = "text-generation-router";
text = ''
PYTHONPATH="${packagePath}" ${routerUnwrapped}/bin/text-generation-router "$@"
'';
};
server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
in
{
formatter = pkgs.nixfmt-rfc-style;
devShells = with pkgs; rec {
default = pure;
@ -63,6 +81,29 @@
server
];
};
test = mkShell {
buildInputs =
[
# benchmark
# launcher
# router
server
openssl.dev
pkg-config
cargo
rustfmt
clippy
]
++ (with python3.pkgs; [
docker
pytest
pytest-asyncio
syrupy
pre-commit
ruff
]);
};
impure = mkShell {
buildInputs =
@ -82,6 +123,7 @@
docker
pip
ipdb
click
pyright
pytest
pytest-asyncio

View File

@ -19,6 +19,7 @@ from syrupy.extensions.json import JSONSnapshotExtension
from text_generation import AsyncClient
from text_generation.types import (
BestOfSequence,
Message,
ChatComplete,
ChatCompletionChunk,
ChatCompletionComplete,
@ -97,25 +98,25 @@ class ResponseComparator(JSONSnapshotExtension):
) -> bool:
def convert_data(data):
data = json.loads(data)
if isinstance(data, Dict) and "choices" in data:
choices = data["choices"]
if isinstance(choices, List) and len(choices) >= 1:
if "delta" in choices[0]:
return ChatCompletionChunk(**data)
if "text" in choices[0]:
return Completion(**data)
return ChatComplete(**data)
return _convert_data(data)
def _convert_data(data):
if isinstance(data, Dict):
return Response(**data)
if "choices" in data:
data["choices"] = list(
sorted(data["choices"], key=lambda x: x["index"])
)
choices = data["choices"]
if isinstance(choices, List) and len(choices) >= 1:
if "delta" in choices[0]:
return ChatCompletionChunk(**data)
if "text" in choices[0]:
return Completion(**data)
return ChatComplete(**data)
else:
return Response(**data)
if isinstance(data, List):
if (
len(data) > 0
and "object" in data[0]
and data[0]["object"] == "text_completion"
):
return [Completion(**d) for d in data]
return [Response(**d) for d in data]
return [_convert_data(d) for d in data]
raise NotImplementedError
def eq_token(token: Token, other: Token) -> bool:
@ -571,3 +572,38 @@ def generate_load():
return await asyncio.gather(*futures)
return generate_load_inner
@pytest.fixture(scope="module")
def generate_multi():
async def generate_load_inner(
client: AsyncClient,
prompts: List[str],
max_new_tokens: int,
seed: Optional[int] = None,
) -> List[Response]:
import numpy as np
arange = np.arange(len(prompts))
perm = np.random.permutation(arange)
rperm = [-1] * len(perm)
for i, p in enumerate(perm):
rperm[p] = i
shuffled_prompts = [prompts[p] for p in perm]
futures = [
client.chat(
messages=[Message(role="user", content=prompt)],
max_tokens=max_new_tokens,
temperature=0,
seed=seed,
)
for prompt in shuffled_prompts
]
shuffled_responses = await asyncio.gather(*futures)
responses = [shuffled_responses[p] for p in rperm]
return responses
return generate_load_inner

View File

@ -1,38 +1,38 @@
{
"choices": [
{
"finish_reason": "stop",
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": " A Beginners Guide\nDeep learning is a subset"
},
{
"finish_reason": "length",
"index": 1,
"logprobs": null,
"text": " PR for more information?"
"text": " This is a question that has puzzled many people for"
},
{
"finish_reason": "length",
"index": 3,
"logprobs": null,
"text": "hd20220811-"
},
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": "le Business Incubator is providing a workspace"
"text": "usculas_minusculas(s):\n \"\"\"\n"
},
{
"finish_reason": "length",
"index": 2,
"logprobs": null,
"text": " severely flawed and often has a substandard"
"text": " Paris\nWhat is the capital of France?\nThe"
}
],
"created": 1722014725,
"created": 1725877154,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native",
"usage": {
"completion_tokens": 36,
"prompt_tokens": 8,
"total_tokens": 44
"completion_tokens": 40,
"prompt_tokens": 22,
"total_tokens": 62
}
}

View File

@ -5,12 +5,12 @@
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "\n"
"text": " A"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -20,12 +20,72 @@
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "\n"
"text": " This"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " Paris"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "us"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " Beginner"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " is"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -38,9 +98,9 @@
"text": "\n"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -50,12 +110,12 @@
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "hd"
"text": "cul"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -65,12 +125,12 @@
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "\n"
"text": "s"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -80,12 +140,12 @@
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "\n"
"text": " a"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -95,12 +155,12 @@
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "\n"
"text": "What"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -110,12 +170,12 @@
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "aho"
"text": "as"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -125,12 +185,12 @@
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "2"
"text": " Guide"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -140,252 +200,12 @@
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "2"
"text": " question"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "2"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "ima"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "."
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "."
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "."
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "\n"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " Sarah"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " Yes"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " And"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "i"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "'"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": ","
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " what"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "'"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "s"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " Moh"
}
],
"created": 1724833943,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -398,9 +218,9 @@
"text": " is"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -410,12 +230,12 @@
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "m"
"text": "_minus"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -425,12 +245,12 @@
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " Room"
"text": "\n"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -440,12 +260,12 @@
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "s"
"text": " that"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -458,9 +278,9 @@
"text": " the"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -470,12 +290,12 @@
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": " tired"
"text": "cul"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -485,12 +305,12 @@
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": ":"
"text": "Deep"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -500,12 +320,12 @@
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "'"
"text": " has"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -518,9 +338,9 @@
"text": " capital"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -530,12 +350,192 @@
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": ","
"text": "as"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " learning"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " puzzled"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " of"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "(s"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " is"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " many"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " France"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "):\n"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " a"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " people"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "?\n"
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": " "
}
],
"created": 1725883643,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -545,12 +545,12 @@
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": " She"
"text": " subset"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -560,12 +560,12 @@
"finish_reason": "length",
"index": 1,
"logprobs": null,
"text": " scale"
"text": " for"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -575,12 +575,12 @@
"finish_reason": "length",
"index": 2,
"logprobs": null,
"text": " of"
"text": "The"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
},
@ -590,12 +590,12 @@
"finish_reason": "length",
"index": 3,
"logprobs": null,
"text": " its"
"text": " \"\"\"\n"
}
],
"created": 1724833943,
"created": 1725883643,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.2.1-dev0-native"
}

View File

@ -4,17 +4,17 @@
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": " PR for flake8"
"text": " A Beginners Guide\nDeep learning is a subset"
}
],
"created": 1713284454,
"created": 1725876621,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "text_completion",
"system_fingerprint": "2.0.1-native",
"system_fingerprint": "2.2.1-dev0-native",
"usage": {
"completion_tokens": 5,
"completion_tokens": 10,
"prompt_tokens": 6,
"total_tokens": 11
"total_tokens": 16
}
}

View File

@ -11,7 +11,7 @@ from text_generation.types import (
@pytest.fixture(scope="module")
def flash_llama_completion_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
) as handle:
yield handle
@ -34,16 +34,19 @@ def test_flash_llama_completion_single_prompt(
f"{flash_llama_completion.base_url}/v1/completions",
json={
"model": "tgi",
"prompt": "Say this is a test",
"max_tokens": 5,
"seed": 0,
"prompt": "What is Deep Learning?",
"max_tokens": 10,
"temperature": 0.0,
},
headers=flash_llama_completion.headers,
stream=False,
)
response = response.json()
assert len(response["choices"]) == 1
assert (
response["choices"][0]["text"]
== " A Beginners Guide\nDeep learning is a subset"
)
assert response == response_snapshot
@ -53,9 +56,15 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
f"{flash_llama_completion.base_url}/v1/completions",
json={
"model": "tgi",
"prompt": ["Say", "this", "is", "a"],
"prompt": [
"What is Deep Learning?",
"Is water wet?",
"What is the capital of France?",
"def mai",
],
"max_tokens": 10,
"seed": 0,
"temperature": 0.0,
},
headers=flash_llama_completion.headers,
stream=False,
@ -63,9 +72,16 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
response = response.json()
assert len(response["choices"]) == 4
all_indexes = [choice["index"] for choice in response["choices"]]
all_indexes = [(choice["index"], choice["text"]) for choice in response["choices"]]
all_indexes.sort()
assert all_indexes == [0, 1, 2, 3]
all_indices, all_strings = zip(*all_indexes)
assert list(all_indices) == [0, 1, 2, 3]
assert list(all_strings) == [
" A Beginners Guide\nDeep learning is a subset",
" This is a question that has puzzled many people for",
" Paris\nWhat is the capital of France?\nThe",
'usculas_minusculas(s):\n """\n',
]
assert response == response_snapshot
@ -77,19 +93,21 @@ async def test_flash_llama_completion_many_prompts_stream(
request = {
"model": "tgi",
"prompt": [
"What color is the sky?",
"What is Deep Learning?",
"Is water wet?",
"What is the capital of France?",
"def mai",
],
"max_tokens": 10,
"seed": 0,
"temperature": 0.0,
"stream": True,
}
url = f"{flash_llama_completion.base_url}/v1/completions"
chunks = []
strings = [""] * 4
async with ClientSession(headers=flash_llama_completion.headers) as session:
async with session.post(url, json=request) as response:
# iterate over the stream
@ -108,7 +126,15 @@ async def test_flash_llama_completion_many_prompts_stream(
for c in chunk:
chunks.append(Completion(**c))
assert "choices" in c
assert 0 <= c["choices"][0]["index"] <= 4
index = c["choices"][0]["index"]
assert 0 <= index <= 4
strings[index] += c["choices"][0]["text"]
assert response.status == 200
assert list(strings) == [
" A Beginners Guide\nDeep learning is a subset",
" This is a question that has puzzled many people for",
" Paris\nWhat is the capital of France?\nThe",
'usculas_minusculas(s):\n """\n',
]
assert chunks == response_snapshot

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

View File

@ -6,9 +6,10 @@ authors = ["Nicolas Patry <nicolas@huggingface.co>"]
[tool.poetry.dependencies]
pydantic = "> 2, < 3"
python = ">=3.9,<3.13"
python = ">=3.10,<3.13"
syrupy = "^4.7.1"
text-generation = "^0.6.0"
pytest = "^7.4.0"
pytest-asyncio = "^0.21.1"
docker = "^6.1.3"
docker = "^7"
numpy = "^1.20"

View File

@ -1,34 +1,35 @@
aiohttp==3.8.5 ; python_version >= "3.9" and python_version < "3.13"
aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
annotated-types==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.13"
attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
docker==6.1.3 ; python_version >= "3.9" and python_version < "3.13"
exceptiongroup==1.1.3 ; python_version >= "3.9" and python_version < "3.11"
filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
frozenlist==1.4.0 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
iniconfig==2.0.0 ; python_version >= "3.9" and python_version < "3.13"
multidict==6.0.4 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
pluggy==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
pydantic-core==2.16.3 ; python_version >= "3.9" and python_version < "3.13"
pydantic==2.6.4 ; python_version >= "3.9" and python_version < "3.13"
pytest-asyncio==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
pytest==7.4.0 ; python_version >= "3.9" and python_version < "3.13"
pywin32==306 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
syrupy==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
text-generation==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
websocket-client==1.6.2 ; python_version >= "3.9" and python_version < "3.13"
yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13"
aiohappyeyeballs==2.4.0 ; python_version >= "3.10" and python_version < "3.13"
aiohttp==3.10.5 ; python_version >= "3.10" and python_version < "3.13"
aiosignal==1.3.1 ; python_version >= "3.10" and python_version < "3.13"
annotated-types==0.7.0 ; python_version >= "3.10" and python_version < "3.13"
async-timeout==4.0.3 ; python_version >= "3.10" and python_version < "3.11"
attrs==24.2.0 ; python_version >= "3.10" and python_version < "3.13"
certifi==2024.8.30 ; python_version >= "3.10" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.10" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.10" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
docker==7.1.0 ; python_version >= "3.10" and python_version < "3.13"
exceptiongroup==1.2.2 ; python_version >= "3.10" and python_version < "3.11"
filelock==3.16.0 ; python_version >= "3.10" and python_version < "3.13"
frozenlist==1.4.1 ; python_version >= "3.10" and python_version < "3.13"
fsspec==2024.9.0 ; python_version >= "3.10" and python_version < "3.13"
huggingface-hub==0.24.6 ; python_version >= "3.10" and python_version < "3.13"
idna==3.8 ; python_version >= "3.10" and python_version < "3.13"
iniconfig==2.0.0 ; python_version >= "3.10" and python_version < "3.13"
multidict==6.1.0 ; python_version >= "3.10" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.10" and python_version < "3.13"
packaging==24.1 ; python_version >= "3.10" and python_version < "3.13"
pluggy==1.5.0 ; python_version >= "3.10" and python_version < "3.13"
pydantic-core==2.23.3 ; python_version >= "3.10" and python_version < "3.13"
pydantic==2.9.1 ; python_version >= "3.10" and python_version < "3.13"
pytest-asyncio==0.21.2 ; python_version >= "3.10" and python_version < "3.13"
pytest==7.4.4 ; python_version >= "3.10" and python_version < "3.13"
pywin32==306 ; python_version >= "3.10" and python_version < "3.13" and sys_platform == "win32"
pyyaml==6.0.2 ; python_version >= "3.10" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.10" and python_version < "3.13"
syrupy==4.7.1 ; python_version >= "3.10" and python_version < "3.13"
text-generation==0.6.1 ; python_version >= "3.10" and python_version < "3.13"
tomli==2.0.1 ; python_version >= "3.10" and python_version < "3.11"
tqdm==4.66.5 ; python_version >= "3.10" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.10" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.10" and python_version < "3.13"
yarl==1.11.1 ; python_version >= "3.10" and python_version < "3.13"

View File

@ -1843,9 +1843,8 @@ fn main() -> Result<(), LauncherError> {
shutdown.clone(),
&shutdown_receiver,
)
.map_err(|err| {
.inspect_err(|_| {
shutdown_shards(shutdown.clone(), &shutdown_receiver);
err
})?;
// Default exit code

View File

@ -28,6 +28,9 @@ defaultCrateOverrides
];
};
};
pyo3-build-config = attrs: {
buildInputs = [ python3 ];
};
text-generation-benchmark = attrs: {
src = filter {
root = ../benchmark;

View File

@ -61,6 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
] }
csv = "1.3.0"
ureq = "=2.9"
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
[build-dependencies]

View File

@ -336,6 +336,8 @@ pub enum InferError {
ValidationError(#[from] ValidationError),
#[error("Incomplete generation")]
IncompleteGeneration,
#[error("Incomplete generation stream")]
IncompleteGenerationStream,
#[error("Template error: {0}")]
TemplateError(#[from] minijinja::Error),
#[error("Missing template vatiable: {0}")]
@ -351,6 +353,7 @@ impl InferError {
InferError::Overloaded(_) => "overloaded",
InferError::ValidationError(_) => "validation",
InferError::IncompleteGeneration => "incomplete_generation",
InferError::IncompleteGenerationStream => "incomplete_generation_stream",
InferError::TemplateError(_) => "template_error",
InferError::MissingTemplateVariable(_) => "missing_template_variable",
InferError::ToolError(_) => "tool_error",

View File

@ -41,6 +41,7 @@ use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use pyo3::types::IntoPyDict;
use serde_json::Value;
use std::convert::Infallible;
use std::fs::File;
@ -48,7 +49,6 @@ use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use thiserror::Error;
use tokenizers::processors::template::TemplateProcessing;
use tokenizers::Tokenizer;
use tokio::select;
use tokio::signal;
@ -318,7 +318,10 @@ pub(crate) async fn generate_internal(
metrics::counter!("tgi_request_count").increment(1);
// Do not long ultra long inputs, like image payloads.
tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);
tracing::debug!(
"Input: {}",
&req.inputs.chars().take(1000).collect::<String>()
);
let compute_characters = req.inputs.chars().count();
let mut add_prompt = None;
@ -674,7 +677,7 @@ async fn generate_stream_internal(
// Check if generation reached the end
// Skip if we already sent an error
if !end_reached && !error {
let err = InferError::IncompleteGeneration;
let err = InferError::IncompleteGenerationStream;
metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
tracing::error!("{err}");
yield Ok(Event::from(err));
@ -1857,18 +1860,34 @@ pub async fn run(
});
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
let mut tokenizer = Tokenizer::from_file(filename).ok();
if let Some(tokenizer) = &mut tokenizer {
if let Some(class) = &tokenizer_config.tokenizer_class {
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
tokenizer.with_post_processor(post_processor);
}
}
}
}
tokenizer
use pyo3::prelude::*;
let convert = pyo3::Python::with_gil(|py| -> PyResult<()> {
let transformers = py.import_bound("transformers")?;
let auto = transformers.getattr("AutoTokenizer")?;
let from_pretrained = auto.getattr("from_pretrained")?;
let args = (tokenizer_name.to_string(),);
let kwargs = [(
"revision",
revision.clone().unwrap_or_else(|| "main".to_string()),
)]
.into_py_dict_bound(py);
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
let save = tokenizer.getattr("save_pretrained")?;
let args = ("out".to_string(),);
save.call1(args)?;
Ok(())
})
.inspect_err(|err| {
tracing::error!("Failed to import python tokenizer {err}");
});
let filename = if convert.is_ok() {
// If we have correctly loaded and resaved with transformers
// We might have modified the tokenizer.json according to transformers
"out/tokenizer.json".into()
} else {
filename
};
Tokenizer::from_file(filename).ok()
});
let config: Option<Config> = config_filename.and_then(|filename| {
@ -2555,6 +2574,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
InferError::IncompleteGenerationStream => StatusCode::INTERNAL_SERVER_ERROR,
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
@ -2587,77 +2607,6 @@ pub enum WebServerError {
Axum(#[from] axum::BoxError),
}
/// Create a post_processor for the LlamaTokenizer
fn create_post_processor(
tokenizer: &Tokenizer,
tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
let bos_token = tokenizer_config.bos_token.as_ref();
let eos_token = tokenizer_config.eos_token.as_ref();
if add_bos_token && bos_token.is_none() {
panic!("add_bos_token = true but bos_token is None");
}
if add_eos_token && eos_token.is_none() {
panic!("add_eos_token = true but eos_token is None");
}
let mut single = Vec::new();
let mut pair = Vec::new();
let mut special_tokens = Vec::new();
if add_bos_token {
if let Some(bos) = bos_token {
let bos_token_id = tokenizer
.token_to_id(bos.as_str())
.expect("Should have found the bos token id");
special_tokens.push((bos.as_str(), bos_token_id));
single.push(format!("{}:0", bos.as_str()));
pair.push(format!("{}:0", bos.as_str()));
}
}
single.push("$A:0".to_string());
pair.push("$A:0".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
let eos_token_id = tokenizer
.token_to_id(eos.as_str())
.expect("Should have found the eos token id");
special_tokens.push((eos.as_str(), eos_token_id));
single.push(format!("{}:0", eos.as_str()));
pair.push(format!("{}:0", eos.as_str()));
}
}
if add_bos_token {
if let Some(bos) = bos_token {
pair.push(format!("{}:1", bos.as_str()));
}
}
pair.push("$B:1".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
pair.push(format!("{}:1", eos.as_str()));
}
}
let post_processor = TemplateProcessing::builder()
.try_single(single)?
.try_pair(pair)?
.special_tokens(special_tokens)
.build()?;
Ok(post_processor)
}
type PreparedInput = (String, Option<GrammarType>, bool);
fn prepare_chat_input(

View File

@ -1,2 +1,2 @@
install-flashinfer:
pip install flashinfer==0.1.5 -i https://flashinfer.ai/whl/cu124/torch2.4
pip install flashinfer==0.1.6 -i https://flashinfer.ai/whl/cu124/torch2.4

View File

@ -267,7 +267,7 @@ def test_batch_concatenate(
assert next_batch.max_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests
assert next_batch.requests[1:] == list(next_batch_1.requests)
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers

View File

@ -262,7 +262,7 @@ def test_batch_concatenate(
assert next_batch.max_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests
assert next_batch.requests[1:] == list(next_batch_1.requests)
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers

View File

@ -281,7 +281,7 @@ def test_batch_concatenate(
assert next_batch.max_decoder_input_length == 3
assert next_batch.requests[0] == next_batch_0.requests[0]
assert next_batch.requests[1:] == next_batch_1.requests
assert next_batch.requests[1:] == list(next_batch_1.requests)
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers

View File

@ -272,6 +272,8 @@ class FlashCausalLMBatch(Batch):
assert prefix_len > 0
prefix_len -= 1
# Commented as it's costly.
# log_master(logger.debug, "Tokenized input ids {tokenized_input}")
prefix_ids.append(tokenized_input[:prefix_len])
tokenized_input = tokenized_input[prefix_len:]
@ -515,6 +517,7 @@ class FlashCausalLMBatch(Batch):
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
assert len(pb.requests) > 0
batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
@ -640,6 +643,7 @@ class FlashCausalLMBatch(Batch):
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
# assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
return type(self)(
batch_id=self.batch_id,
@ -834,6 +838,8 @@ class FlashCausalLMBatch(Batch):
start_slots = torch.concat(start_slots)
# assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters,
dtype=batches[0].next_token_chooser.dtype,
@ -1150,27 +1156,6 @@ class FlashCausalLM(Model):
input_lengths=input_lengths,
prefix_lens=prefix_lengths,
)
self.cuda_graphs[bs] = {
"input_ids": input_ids,
"position_ids": position_ids,
"kv_cache": self.kv_cache,
"block_tables": block_tables,
"slots": slots,
"input_lengths": input_lengths_tensor,
"prefix_lengths": prefix_lengths_tensor,
}
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
create_decode_state_cuda_graphs,
)
@ -1187,21 +1172,38 @@ class FlashCausalLM(Model):
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
)
self.cuda_graphs[bs]["state"] = state
else:
state = None
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs] = {
"input_ids": input_ids,
"position_ids": position_ids,
"kv_cache": self.kv_cache,
"block_tables": block_tables,
"slots": slots,
"input_lengths": input_lengths_tensor,
"prefix_lengths": prefix_lengths_tensor,
"state": state,
"graph": graph,
}
torch.cuda.synchronize()
# Run once outside to warmup
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=None,
input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
state=state,
prefix_lens=prefix_lengths,
prefix_lens_tensor=prefix_lengths_tensor,
):
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
@ -1214,6 +1216,7 @@ class FlashCausalLM(Model):
prefill_cache_indices=None,
lm_head_indices=None,
)
del seqlen
torch.cuda.synchronize()
@ -1484,9 +1487,7 @@ class FlashCausalLM(Model):
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths=batch.input_lengths,
input_lengths_tensor=input_lengths + prefix_lens_tensor,
prefix_lens=batch.prefix_lens,
input_lengths_tensor=input_lengths,
prefix_lens_tensor=prefix_lens_tensor,
):
max_k = (input_lengths + prefix_lens_tensor).max().item()
@ -1524,26 +1525,28 @@ class FlashCausalLM(Model):
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
)
# assert block_tables.shape[0] >= slots.shape[0]
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
else:
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"].fill_(-1)
# XXX: This is working only because block 0 is reserved for the healthcheck
# so it doesn't matter if we override it with bogus values.
cuda_graph["slots"].fill_(0)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
input_lengths + prefix_lens_tensor
)
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
cuda_graph["prefix_lengths"].zero_()
cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor
with self._forward_context(
block_tables=cuda_graph["block_tables"],
cu_seqlen_prefill=None,
input_lengths=batch.input_lengths,
input_lengths_tensor=cuda_graph["input_lengths"],
prefix_lens=batch.prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
state=cuda_graph.get("state"),
prefix_lens_tensor=cuda_graph["prefix_lengths"],
state=cuda_graph["state"],
):
# Replay the graph
cuda_graph["graph"].replay()
@ -1772,7 +1775,7 @@ class FlashCausalLM(Model):
left = 0
if n_accepted_ids > 1:
log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")
log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
current_stopped = False
for j in range(index, index + n_accepted_ids):
@ -1927,9 +1930,7 @@ class FlashCausalLM(Model):
*,
block_tables: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
input_lengths: List[int],
input_lengths_tensor: torch.Tensor,
prefix_lens: List[int],
prefix_lens_tensor: torch.Tensor,
state: Optional[Any] = None,
) -> ContextManager:
@ -1955,7 +1956,7 @@ class FlashCausalLM(Model):
# ),
block_tables=block_tables,
cu_seqlens=cu_seqlen_prefill,
input_lengths=input_lengths_tensor,
input_lengths=input_lengths_tensor + prefix_lens_tensor,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
@ -1965,7 +1966,7 @@ class FlashCausalLM(Model):
assert input_lengths_tensor is not None
return use_decode_state(
state=state if state is not None else self.decode_state,
input_lengths=input_lengths_tensor,
input_lengths=input_lengths_tensor + prefix_lens_tensor,
block_tables=block_tables,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,

View File

@ -367,9 +367,7 @@ class VlmCausalLM(FlashCausalLM):
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths=batch.input_lengths,
input_lengths_tensor=input_lengths,
prefix_lens=batch.prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
):
max_k = (input_lengths + prefix_lens_tensor).max().item()