Merge branch 'main' into feat/add-load-test

This commit is contained in:
Hugo Larcher 2024-08-30 15:31:55 +02:00
commit 345d47362f
No known key found for this signature in database
GPG Key ID: 3DAF63124699CA2B
144 changed files with 39803 additions and 35173 deletions

View File

@@ -32,10 +32,6 @@ jobs:
 permissions:
 contents: write
 packages: write
-# This is used to complete the identity challenge
-# with sigstore/fulcio when running outside of PRs.
-id-token: write
-security-events: write
 steps:
 - name: Checkout repository
 uses: actions/checkout@v4
@@ -50,6 +46,7 @@ jobs:
 export label_extension=""
 export docker_devices=""
 export runs_on="aws-g6-12xlarge-plus-priv"
+export platform=""
 ;;
 rocm)
 export dockerfile="Dockerfile_amd"
@@ -58,12 +55,21 @@ jobs:
 # TODO Re-enable when they pass.
 # export runs_on="amd-gpu-tgi"
 export runs_on="ubuntu-latest"
+export platform=""
 ;;
-intel)
+intel-xpu)
 export dockerfile="Dockerfile_intel"
-export label_extension="-intel"
+export label_extension="-intel-xpu"
 export docker_devices=""
 export runs_on="ubuntu-latest"
+export platform="xpu"
+;;
+intel-cpu)
+export dockerfile="Dockerfile_intel"
+export label_extension="-intel-cpu"
+export docker_devices=""
+export runs_on="ubuntu-latest"
+export platform="cpu"
 ;;
 esac
 echo $dockerfile
@@ -71,8 +77,10 @@ jobs:
 echo $label_extension
 echo $docker_devices
 echo $runs_on
+echo $platform
 echo "DOCKERFILE=${dockerfile}" >> $GITHUB_ENV
 echo "LABEL=${label_extension}" >> $GITHUB_ENV
+echo "PLATFORM=${platform}" >> $GITHUB_ENV
 echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV
 echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV
 echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV
@@ -139,6 +147,7 @@ jobs:
 build-args: |
 GIT_SHA=${{ env.GITHUB_SHA }}
 DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
+PLATFORM=${{ env.PLATFORM }}
 tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
 labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
 cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
@@ -159,7 +168,7 @@ jobs:
 group: ${{ needs.build-and-push.outputs.runs_on }}
 if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
 env:
-PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }}
+PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '--release' }}
 steps:
 - name: Checkout repository
 uses: actions/checkout@v4

View File

@@ -11,7 +11,7 @@ concurrency:
 jobs:
 build:
-uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yaml@main
+uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
 with:
 commit_sha: ${{ github.event.pull_request.head.sha }}
 pr_number: ${{ github.event.number }}

View File

@@ -37,8 +37,11 @@ jobs:
 # fail-fast is true by default
 fail-fast: false
 matrix:
-hardware: ["cuda", "rocm", "intel"]
+hardware: ["cuda", "rocm", "intel-xpu", "intel-cpu"]
 uses: ./.github/workflows/build.yaml # calls the one above ^
+permissions:
+contents: write
+packages: write
 with:
 hardware: ${{ matrix.hardware }}
 # https://github.com/actions/runner/issues/2206

View File

@@ -35,7 +35,7 @@ jobs:
 with:
 # Released on: 02 May, 2024
 # https://releases.rs/docs/1.78.0/
-toolchain: 1.79.0
+toolchain: 1.80.0
 override: true
 components: rustfmt, clippy
 - name: Install Protoc

3 .gitignore vendored
View File

@@ -9,7 +9,7 @@ backends/client/src/v3/pb
 # ROCm auto-generated files
 *.hip
-server/exllamav2_kernels/exllamav2_kernels/hip/
+server/exllamav2
 server/exllama_kernels/exllama_kernels/hip/
 server/exllama_kernels/exllama_kernels/hip_func/
 *_hip.cuh
@@ -18,3 +18,4 @@ server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
 data/
 load_tests/*.json
+server/fbgemmm

View File

@ -5,7 +5,7 @@ repos:
- id: check-yaml - id: check-yaml
- id: end-of-file-fixer - id: end-of-file-fixer
- id: trailing-whitespace - id: trailing-whitespace
exclude: docs/source/basic_tutorials/launcher.md exclude: docs/source/reference/launcher.md
- repo: https://github.com/psf/black - repo: https://github.com/psf/black
rev: 24.2.0 rev: 24.2.0
hooks: hooks:

View File

@@ -77,3 +77,4 @@ docs/openapi.json:
 - '#/paths/~1tokenize/post'
 - '#/paths/~1v1~1chat~1completions/post'
 - '#/paths/~1v1~1completions/post'
+- '#/paths/~1v1~1models/get'

607 Cargo.lock generated

File diff suppressed because it is too large

View File

@@ -29,6 +29,8 @@ tokenizers = { version = "0.19.1", features = ["http"] }
 hf-hub = { version = "0.3.1", features = ["tokio"] }
 metrics = { version = "0.23.0" }
 metrics-exporter-prometheus = { version = "0.15.1", features = [] }
+minijinja = { version = "2.2.0", features = ["json"] }
+minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
 [profile.release]
 incremental = true

View File

@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
 WORKDIR /usr/src
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -40,14 +40,14 @@ RUN cargo build --profile release-opt
 # Python builder
 # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
-FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install
+FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
 # NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
 ARG PYTORCH_VERSION=2.4.0
 ARG PYTHON_VERSION=3.10
 # Keep in sync with `server/pyproject.toml
-ARG CUDA_VERSION=12.1
+ARG CUDA_VERSION=12.4
 ARG MAMBA_VERSION=24.3.0-0
 ARG CUDA_CHANNEL=nvidia
 ARG INSTALL_CHANNEL=pytorch
@@ -88,6 +88,7 @@ RUN case ${TARGETPLATFORM} in \
 FROM pytorch-install AS kernel-builder
 ARG MAX_JOBS=8
+ENV TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0+PTX"
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
 ninja-build cmake \
@@ -118,29 +119,29 @@ FROM kernel-builder AS exllama-kernels-builder
 WORKDIR /usr/src
 COPY server/exllama_kernels/ .
-RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
+RUN python setup.py build
 # Build Transformers exllama kernels
 FROM kernel-builder AS exllamav2-kernels-builder
 WORKDIR /usr/src
-COPY server/exllamav2_kernels/ .
+COPY server/Makefile-exllamav2/ Makefile
 # Build specific version of transformers
-RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
+RUN make build-exllamav2
 # Build Transformers awq kernels
 FROM kernel-builder AS awq-kernels-builder
 WORKDIR /usr/src
 COPY server/Makefile-awq Makefile
 # Build specific version of transformers
-RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq
+RUN make build-awq
 # Build eetq kernels
 FROM kernel-builder AS eetq-kernels-builder
 WORKDIR /usr/src
 COPY server/Makefile-eetq Makefile
 # Build specific version of transformers
-RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq
+RUN make build-eetq
 # Build Lorax Punica kernels
 FROM kernel-builder AS lorax-punica-builder
@@ -183,6 +184,12 @@ WORKDIR /usr/src
 COPY server/Makefile-selective-scan Makefile
 RUN make build-all
+# Build flashinfer
+FROM kernel-builder AS flashinfer-builder
+WORKDIR /usr/src
+COPY server/Makefile-flashinfer Makefile
+RUN make install-flashinfer
 # Text Generation Inference base image
 FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base
@@ -191,7 +198,7 @@ ENV PATH=/opt/conda/bin:$PATH \
 CONDA_PREFIX=/opt/conda
 # Text Generation Inference base env
-ENV HUGGINGFACE_HUB_CACHE=/data \
+ENV HF_HOME=/data \
 HF_HUB_ENABLE_HF_TRANSFER=1 \
 PORT=80
@@ -221,11 +228,13 @@ COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /
 # Copy build artifacts from exllama kernels builder
 COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from exllamav2 kernels builder
-COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from awq kernels builder
 COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from eetq kernels builder
 COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+# Copy build artifacts from lorax punica kernels builder
+COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from fbgemm builder
 COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
 # Copy build artifacts from vllm builder
@@ -233,6 +242,7 @@ COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/c
 # Copy build artifacts from mamba builder
 COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
 COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
+COPY --from=flashinfer-builder /opt/conda/lib/python3.10/site-packages/flashinfer/ /opt/conda/lib/python3.10/site-packages/flashinfer/
 # Install flash-attention dependencies
 RUN pip install einops --no-cache-dir
@@ -248,6 +258,9 @@ RUN cd server && \
 pip install nvidia-nccl-cu12==2.22.3
 ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
+# This is needed because exl2 tries to load flash-attn
+# And fails with our builds.
+ENV EXLLAMA_NO_FLASH_ATTN=1
 # Deps before the binaries
 # The binaries change on every build given we burn the SHA into them

View File

@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
 WORKDIR /usr/src
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -167,7 +167,7 @@ RUN python setup.py build
 FROM base AS base-copy
 # Text Generation Inference base env
-ENV HUGGINGFACE_HUB_CACHE=/data \
+ENV HF_HOME=/data \
 HF_HUB_ENABLE_HF_TRANSFER=1 \
 PORT=80

View File

@@ -1,6 +1,6 @@
 ARG PLATFORM=xpu
-FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
 WORKDIR /usr/src
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -57,7 +57,7 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO
 RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils
 # Text Generation Inference base env
-ENV HUGGINGFACE_HUB_CACHE=/data \
+ENV HF_HOME=/data \
 HF_HUB_ENABLE_HF_TRANSFER=1 \
 PORT=80
@@ -106,7 +106,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
 g++ \
 git \
 wget \
-cmake
+cmake \
+libnuma-dev
 ENV HUGGINGFACE_HUB_CACHE=/data \
 HF_HUB_ENABLE_HF_TRANSFER=1 \
@@ -135,7 +136,7 @@ RUN conda install -c conda-forge gperftools mkl
 RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
 RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
 RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
-RUN pip install triton
+RUN pip install triton numa
 WORKDIR /usr/src
@@ -147,16 +148,11 @@ RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update
 RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
-ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so
+ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
 ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
 ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
 ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
 ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
-ENV KMP_BLOCKTIME=1
-ENV KMP_TPAUSE=0
-ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist
-ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist
-ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist
 # Install server
 COPY proto proto

View File

@@ -13,7 +13,7 @@
 <img alt="Swagger API documentation" src="https://img.shields.io/badge/API-Swagger-informational">
 </a>
-A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
+A Rust, Python and gRPC server for text generation inference. Used in production at [Hugging Face](https://huggingface.co)
 to power Hugging Chat, the Inference API and Inference Endpoint.
 </div>
@@ -42,12 +42,15 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan
 - Tensor Parallelism for faster inference on multiple GPUs
 - Token streaming using Server-Sent Events (SSE)
 - Continuous batching of incoming requests for increased total throughput
+- [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) compatible with Open AI Chat Completion API
 - Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
 - Quantization with :
 - [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
 - [GPT-Q](https://arxiv.org/abs/2210.17323)
 - [EETQ](https://github.com/NetEase-FuXi/EETQ)
 - [AWQ](https://github.com/casper-hansen/AutoAWQ)
+- [Marlin](https://github.com/IST-DASLab/marlin)
+- [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
 - Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
@@ -92,6 +95,29 @@ curl 127.0.0.1:8080/generate_stream \
 -H 'Content-Type: application/json'
 ```
+You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain Open AI Chat Completion API compatible responses.
+```bash
+curl localhost:3000/v1/chat/completions \
+    -X POST \
+    -d '{
+  "model": "tgi",
+  "messages": [
+    {
+      "role": "system",
+      "content": "You are a helpful assistant."
+    },
+    {
+      "role": "user",
+      "content": "What is deep learning?"
+    }
+  ],
+  "stream": true,
+  "max_tokens": 20
+}' \
+    -H 'Content-Type: application/json'
+```
 **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
 **Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.2.0-rocm --model-id $model` instead of the command above.
@@ -120,7 +146,7 @@ For example, if you want to serve the gated Llama V2 model variants:
 or with Docker:
 ```shell
-model=meta-llama/Llama-2-7b-chat-hf
+model=meta-llama/Meta-Llama-3.1-8B-Instruct
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 token=<your cli READ token>
@@ -232,7 +258,7 @@ text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2
 ### Quantization
-You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:
+You can also run pre-quantized weights (AWQ, GPTQ, Marlin) or on-the-fly quantize weights with bitsandbytes, EETQ, fp8, to reduce the VRAM requirement:
 ```shell
 text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantize
@@ -240,6 +266,8 @@ text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantiz
 4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.
+Read more about quantization in the [Quantization documentation](https://huggingface.co/docs/text-generation-inference/en/conceptual/quantization).
 ## Develop
 ```shell

17 _server.nix Normal file
View File

@ -0,0 +1,17 @@
{
mkPoetryApplication,
pkg-config,
protobuf,
openssl,
}:
mkPoetryApplication {
# name = "text-generation-server";
projectDir = ./server;
# nativeBuildInputs = [ pkg-config ];
# buildInputs = [ openssl.dev protobuf ];
}

View File

@@ -153,9 +153,12 @@ impl Client {
 }),
 // We truncate the input on the server side to be sure that it has the correct size
 truncate,
+// Most request will have that
+add_special_tokens: true,
 // Blocks and slots will be set on the server side if we use paged attention
 blocks: vec![],
 slots: vec![],
+prefix_len: 0,
 // Set sampling parameters to also take these ops into account in the max memory
 parameters: Some(NextTokenChooserParameters {
 temperature: 0.9,

View File

@@ -221,6 +221,7 @@ impl Health for ShardedClient {
 chunks: vec![Chunk::Text("liveness".into()).into()],
 }),
 truncate: 10,
+add_special_tokens: true,
 prefill_logprobs: false,
 parameters: Some(NextTokenChooserParameters {
 temperature: 1.0,
@@ -244,6 +245,7 @@ impl Health for ShardedClient {
 // Block 0 is reserved for health checks
 blocks: vec![0],
 slots: (0..16).collect(),
+prefix_len: 0,
 adapter_id: None,
 };
 let batch = Batch {

View File

@ -8,17 +8,18 @@ homepage.workspace = true
[dependencies] [dependencies]
async-trait = "0.1" async-trait = "0.1"
async-stream = "0.3" async-stream = "0.3"
clap = { version = "4.5", features = ["derive"] }
cxx = "1.0" cxx = "1.0"
log = { version = "0.4", features = [] }
text-generation-router = { path = "../../router" } text-generation-router = { path = "../../router" }
tokenizers = { version = "0.19", features = ["hf-hub"] } tokenizers = { version = "0.19", features = ["hf-hub"] }
tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.15" tokio-stream = "0.1.15"
clap = { version = "4.5", features = ["derive"] }
thiserror = "1.0.62" thiserror = "1.0.62"
tracing = "0.1" tracing = "0.1"
tracing-opentelemetry = "0.24" tracing-opentelemetry = "0.24"
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
log = { version = "0.4", features = [] } parking_lot = "0.12"
[build-dependencies] [build-dependencies]
cmake = "0.1" cmake = "0.1"

View File

@ -3,7 +3,7 @@ ARG OMPI_VERSION="4.1.6"
# Build dependencies resolver stage # Build dependencies resolver stage
FROM lukemathwalker/cargo-chef:latest AS chef FROM lukemathwalker/cargo-chef:latest AS chef
WORKDIR /usr/src/text-generation-inference WORKDIR /usr/src/text-generation-inference/backends/trtllm
FROM chef AS planner FROM chef AS planner
COPY . . COPY . .
@ -42,7 +42,7 @@ RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILE
mkdir /usr/src/mpi && \ mkdir /usr/src/mpi && \
tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \ tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
cd /usr/src/mpi && \ cd /usr/src/mpi && \
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --without-slurm && \ ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda && \
make -j all && \ make -j all && \
make install && \ make install && \
rm -rf "/opt/src/$OMPI_TARBALL_FILENAME" rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
@ -66,7 +66,7 @@ ENV PATH="/root/.cargo/bin:$PATH"
RUN cargo install cargo-chef RUN cargo install cargo-chef
# Cache dependencies # Cache dependencies
COPY --from=planner /usr/src/text-generation-inference/recipe.json . COPY --from=planner /usr/src/text-generation-inference/backends/trtllm/recipe.json .
RUN cargo chef cook --release --recipe-path recipe.json RUN cargo chef cook --release --recipe-path recipe.json
# Build actual TGI # Build actual TGI
@ -79,7 +79,8 @@ COPY . .
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \ RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release --bin text-generation-backends-trtllm cd backends/trtllm && \
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release
FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime
WORKDIR /usr/local/tgi/bin WORKDIR /usr/local/tgi/bin

View File

@@ -12,12 +12,13 @@ use cxx::UniquePtr;
 use log::{error, warn};
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
-use tokio::sync::RwLock;
 use tokio::time::{sleep, Instant};
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::{Stream, StreamExt};
 use tracing::{instrument, span, Level};
+// use tokio::sync::RwLock;
+use parking_lot::RwLock;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidationError::UnsupportedModality;
 use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError};

View File

@@ -1,12 +1,10 @@
-use clap::Parser;
 use std::collections::HashMap;
 use std::path::PathBuf;
+use clap::Parser;
-use tokenizers::{FromPretrainedParameters, Tokenizer};
 use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
 use text_generation_backends_trtllm::TensorRtLlmBackend;
 use text_generation_router::server;
+use tokenizers::{FromPretrainedParameters, Tokenizer};
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -160,6 +158,8 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
 messages_api_enabled,
 true,
 max_client_batch_size,
+false,
+false,
 )
 .await?;
 Ok(())

View File

@ -33,9 +33,16 @@ rand = "0.8.5"
reqwest = { version = "0.11.20", features = [] } reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188" serde = "1.0.188"
serde_json = "1.0.107" serde_json = "1.0.107"
slotmap = "1.0.7"
thiserror = "1.0.48" thiserror = "1.0.48"
tokenizers = { workspace = true} tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.32.0", features = [
"rt",
"rt-multi-thread",
"parking_lot",
"signal",
"sync",
] }
tokio-stream = "0.1.14" tokio-stream = "0.1.14"
tower-http = { version = "0.5.1", features = ["cors"] } tower-http = { version = "0.5.1", features = ["cors"] }
tracing = "0.1.37" tracing = "0.1.37"
@ -43,9 +50,11 @@ tracing-opentelemetry = "0.21.0"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } init-tracing-opentelemetry = { version = "0.14.1", features = [
minijinja = { version = "2.0.2" } "opentelemetry-otlp",
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } ] }
minijinja = { workspace = true }
minijinja-contrib = { workspace = true }
futures-util = "0.3.30" futures-util = "0.3.30"
regex = "1.10.3" regex = "1.10.3"
once_cell = "1.19.0" once_cell = "1.19.0"
@ -59,8 +68,16 @@ tower = "^0.4"
tonic-build = "0.10.1" tonic-build = "0.10.1"
prost-build = "0.12.1" prost-build = "0.12.1"
[dev-dependencies]
criterion = "0.3"
itertools = "0.13"
[features] [features]
default = ["ngrok"] default = ["ngrok"]
ngrok = ["text-generation-router/ngrok"] ngrok = ["text-generation-router/ngrok"]
google = ["text-generation-router/google"] google = ["text-generation-router/google"]
kserve = ["text-generation-router/kserve"] kserve = ["text-generation-router/kserve"]
[[bench]]
name = "prefix_cache"
harness = false

View File

@ -0,0 +1,47 @@
use std::sync::Arc;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rand::Rng;
use text_generation_router_v3::block_allocator::Allocator;
use text_generation_router_v3::radix::RadixAllocator;
fn prefix_cache_benchmark(c: &mut Criterion) {
// let prefixes: Vec<Vec<u32>> = (0..8192)
// .chunks(256)
// .into_iter()
// .map(|c| c.collect())
// .collect();
let mut cache = RadixAllocator::new(1, 262144, None);
c.bench_function("Radix allocator", |b| {
b.iter_batched(
|| {
//prefixes
// .choose_multiple(&mut rand::thread_rng(), 5)
// .fold(Vec::new(), |mut v, s| {
// v.extend(s);
// v
// })
(0..7936)
.map(|_| rand::thread_rng().gen_range(0..1024))
.collect::<Vec<u32>>()
},
|prefill| {
let alloc = cache.allocate(
prefill.len() as u32 + 13,
Some(Arc::new(black_box(prefill))),
);
if let Some(alloc) = alloc {
cache.free(alloc.blocks.clone(), alloc.allocation_id);
}
},
criterion::BatchSize::SmallInput,
);
});
}
criterion_group!(benches, prefix_cache_benchmark);
criterion_main!(benches);

View File

@@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
-use text_generation_router::{FinishReason, PrefillToken, Token};
+use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
 use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify};
 use tokio::time::Instant;
@@ -35,16 +35,20 @@ impl BackendV3 {
 window_size: Option<u32>,
 speculate: u32,
 ) -> Self {
-let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
-    matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
-} else {
-    false
-};
-let block_size = if flashdecoding { 256 } else { 16 };
+let prefix_caching =
+    std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
+let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
+let attention: String = std::env::var("ATTENTION").expect("attention env var");
+let attention: Attention = attention
+    .parse()
+    .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
+let block_size = attention.block_size();
 let queue = Queue::new(
 requires_padding,
 block_size,
+prefix_caching,
 window_size,
 speculate,
 max_batch_total_tokens,
@@ -168,7 +172,8 @@ pub(crate) async fn batching_task(
 };
 let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
-let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
+let max_size =
+    max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
 // Try to get a new batch
 if let Some((mut new_entries, new_batch, span)) = queue
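The switch to `saturating_sub` in the hunk above guards against an unsigned underflow: with `usize`, `max_size - batch_size` panics in debug builds (and wraps in release) once the running batch is already at or above `max_batch_size`; a saturated result of 0 is then treated as "no capacity" by the queue. A minimal standalone sketch of the difference, not taken from the repository:

```rust
// Standalone illustration (not TGI code) of why the hunk above switches from
// plain subtraction to `saturating_sub` when computing the remaining batch slots.
fn remaining_slots(max_batch_size: Option<usize>, batch_size: usize) -> Option<usize> {
    // `max_size - batch_size` would panic in debug builds (and wrap in release)
    // whenever batch_size > max_size; `saturating_sub` clamps the result at 0,
    // and a budget of 0 is rejected before the next batch is built.
    max_batch_size.map(|max_size| max_size.saturating_sub(batch_size))
}

fn main() {
    assert_eq!(remaining_slots(Some(8), 3), Some(5));
    assert_eq!(remaining_slots(Some(2), 3), Some(0)); // would underflow with `-`
    assert_eq!(remaining_slots(None, 3), None);       // no limit configured
}
```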

View File

@ -1,21 +1,31 @@
use std::cmp::min; use std::sync::Arc;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use crate::radix::RadixAllocator;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct BlockAllocation { pub struct BlockAllocation {
pub allocation_id: u64,
pub blocks: Vec<u32>, pub blocks: Vec<u32>,
pub slots: Vec<u32>, pub slots: Vec<u32>,
block_allocator: BlockAllocator,
/// Prefix that was cached and for which the KV does not have to
/// be recomputed.
pub prefix_len: u32,
pub(crate) block_allocator: Option<BlockAllocator>,
} }
impl Drop for BlockAllocation { impl Drop for BlockAllocation {
fn drop(&mut self) { fn drop(&mut self) {
self.block_allocator.free(self.blocks.clone()) if let Some(block_allocator) = self.block_allocator.as_mut() {
block_allocator.free(self.blocks.clone(), self.allocation_id)
}
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct BlockAllocator { pub struct BlockAllocator {
/// Channel to communicate with the background task /// Channel to communicate with the background task
block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>, block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
} }
@ -24,6 +34,7 @@ impl BlockAllocator {
pub(crate) fn new( pub(crate) fn new(
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
block_size: u32, block_size: u32,
prefix_caching: bool,
window_size: Option<u32>, window_size: Option<u32>,
) -> Self { ) -> Self {
// Create channel // Create channel
@ -33,6 +44,7 @@ impl BlockAllocator {
tokio::spawn(block_allocator_task( tokio::spawn(block_allocator_task(
max_batch_total_tokens / block_size, max_batch_total_tokens / block_size,
block_size, block_size,
prefix_caching,
window_size, window_size,
receiver, receiver,
)); ));
@ -42,28 +54,32 @@ impl BlockAllocator {
} }
} }
pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> { pub(crate) async fn allocate(
&self,
tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
let (response_sender, response_receiver) = oneshot::channel(); let (response_sender, response_receiver) = oneshot::channel();
self.block_allocator self.block_allocator
.send(BlockAllocatorCommand::Allocate { .send(BlockAllocatorCommand::Allocate {
tokens, tokens,
prefill_tokens,
response_sender, response_sender,
}) })
.unwrap(); .unwrap();
response_receiver response_receiver.await.unwrap().map(|mut allocation| {
.await allocation.block_allocator = Some(self.clone());
.unwrap() allocation
.map(|(blocks, slots)| BlockAllocation { })
blocks,
slots,
block_allocator: self.clone(),
})
} }
pub(crate) fn free(&self, blocks: Vec<u32>) { pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {
self.block_allocator self.block_allocator
.send(BlockAllocatorCommand::Free { blocks }) .send(BlockAllocatorCommand::Free {
allocation_id,
blocks,
})
.unwrap(); .unwrap();
} }
} }
@ -71,54 +87,29 @@ impl BlockAllocator {
async fn block_allocator_task( async fn block_allocator_task(
blocks: u32, blocks: u32,
block_size: u32, block_size: u32,
prefix_caching: bool,
window_size: Option<u32>, window_size: Option<u32>,
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>, mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) { ) {
// Block 0 is reserved for health checks let mut allocator: Box<dyn Allocator + Send> = if prefix_caching {
let mut free_blocks: Vec<u32> = (1..blocks).collect(); Box::new(RadixAllocator::new(block_size, blocks, window_size))
} else {
Box::new(SimpleAllocator::new(blocks, block_size, window_size))
};
while let Some(cmd) = receiver.recv().await { while let Some(cmd) = receiver.recv().await {
match cmd { match cmd {
BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks), BlockAllocatorCommand::Free {
blocks,
allocation_id,
} => allocator.free(blocks, allocation_id),
BlockAllocatorCommand::Allocate { BlockAllocatorCommand::Allocate {
tokens, tokens,
prefill_tokens,
response_sender, response_sender,
} => { } => {
// Apply window size response_sender
let (required_blocks, repeats) = { .send(allocator.allocate(tokens, prefill_tokens))
let (tokens, repeats) = match window_size { .unwrap();
None => (tokens, 1),
Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size;
let tokens = min(tokens, window_size);
(tokens, repeats as usize)
}
};
// Pad to a multiple of block size
let required_blocks = (tokens + block_size - 1) / block_size;
(required_blocks, repeats)
};
let tokens = tokens as usize;
let allocation = if required_blocks > free_blocks.len() as u32 {
None
} else {
let blocks =
free_blocks.split_off(free_blocks.len() - required_blocks as usize);
let mut slots = Vec::with_capacity(
(required_blocks * block_size * repeats as u32) as usize,
);
'slots: for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * block_size)..((block_id + 1) * block_size) {
slots.push(s);
if slots.len() == tokens {
break 'slots;
}
}
}
Some((blocks, slots))
};
response_sender.send(allocation).unwrap();
} }
} }
} }
@ -128,9 +119,91 @@ async fn block_allocator_task(
enum BlockAllocatorCommand { enum BlockAllocatorCommand {
Free { Free {
blocks: Vec<u32>, blocks: Vec<u32>,
allocation_id: u64,
}, },
Allocate { Allocate {
tokens: u32, tokens: u32,
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>, prefill_tokens: Option<Arc<Vec<u32>>>,
response_sender: oneshot::Sender<Option<BlockAllocation>>,
}, },
} }
pub trait Allocator {
fn allocate(
&mut self,
tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation>;
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
}
pub struct SimpleAllocator {
free_blocks: Vec<u32>,
block_size: u32,
window_size: Option<u32>,
}
impl SimpleAllocator {
fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
SimpleAllocator {
block_size,
// Block 0 is reserved for health checks
free_blocks: (1..blocks).collect(),
window_size,
}
}
}
impl Allocator for SimpleAllocator {
fn allocate(
&mut self,
tokens: u32,
_prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
// Apply window size
let (required_blocks, repeats) = {
let (tokens, repeats) = match self.window_size {
None => (tokens, 1),
Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size;
let tokens = core::cmp::min(tokens, window_size);
(tokens, repeats as usize)
}
};
// Pad to a multiple of block size
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
(required_blocks, repeats)
};
let tokens = tokens as usize;
if required_blocks > self.free_blocks.len() as u32 {
None
} else {
let blocks = self
.free_blocks
.split_off(self.free_blocks.len() - required_blocks as usize);
let mut slots =
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
'slots: for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
slots.push(s);
if slots.len() == tokens {
break 'slots;
}
}
}
Some(BlockAllocation {
allocation_id: 0,
blocks,
slots,
prefix_len: 0,
block_allocator: None,
})
}
}
fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
self.free_blocks.extend(blocks)
}
}
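The `SimpleAllocator` added above keeps the previous first-fit behaviour behind the new `Allocator` trait: the token count is rounded up to whole blocks, and slots are enumerated block by block until the request is covered. A small standalone sketch of that arithmetic, assuming no sliding window and using a hypothetical `blocks_and_slots` helper rather than the trait itself:

```rust
// Standalone illustration (not TGI code) of SimpleAllocator's padding arithmetic.
// Slots are numbered block_id * block_size .. (block_id + 1) * block_size and the
// list is truncated to the number of tokens actually requested.
fn blocks_and_slots(tokens: u32, block_size: u32, free_blocks: &[u32]) -> Option<(Vec<u32>, Vec<u32>)> {
    let required_blocks = (tokens + block_size - 1) / block_size; // ceil division
    if required_blocks as usize > free_blocks.len() {
        return None; // over budget: the caller re-queues the request
    }
    // This sketch takes blocks from the front; the real allocator pops them off
    // the end of its free list with `split_off`.
    let blocks: Vec<u32> = free_blocks[..required_blocks as usize].to_vec();
    let mut slots = Vec::with_capacity((required_blocks * block_size) as usize);
    'outer: for block_id in &blocks {
        for s in (block_id * block_size)..((block_id + 1) * block_size) {
            slots.push(s);
            if slots.len() == tokens as usize {
                break 'outer;
            }
        }
    }
    Some((blocks, slots))
}

fn main() {
    // 10 tokens with 16-token blocks need 1 block and exactly 10 slots.
    let (blocks, slots) = blocks_and_slots(10, 16, &[1, 2, 3]).unwrap();
    assert_eq!(blocks, vec![1]);
    assert_eq!(slots, (16..26).collect::<Vec<u32>>());
}
```

The real allocator additionally repeats the block list when a sliding window is configured, so the slot count covers the full requested token budget.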

View File

@@ -149,6 +149,7 @@ impl Client {
 requests.push(Request {
 id: 0,
 inputs,
+add_special_tokens: true,
 input_chunks: Some(Input {
 chunks: input_chunks,
 }),
@@ -157,6 +158,7 @@ impl Client {
 // Blocks and slots will be set on the server side if we use paged attention
 blocks: vec![],
 slots: vec![],
+prefix_len: 0,
 // Set sampling parameters to also take these ops into account in the max memory
 parameters: Some(NextTokenChooserParameters {
 temperature: 0.9,

View File

@@ -222,6 +222,7 @@ impl Health for ShardedClient {
 chunks: vec![Chunk::Text("liveness".into()).into()],
 }),
 truncate: 10,
+add_special_tokens: true,
 prefill_logprobs: false,
 parameters: Some(NextTokenChooserParameters {
 temperature: 1.0,
@@ -245,6 +246,7 @@ impl Health for ShardedClient {
 // Block 0 is reserved for health checks
 blocks: vec![0],
 slots: (0..16).collect(),
+prefix_len: 0,
 adapter_id: None,
 };
 let batch = Batch {

View File

@@ -1,7 +1,8 @@
 mod backend;
-mod block_allocator;
+pub mod block_allocator;
 mod client;
 mod queue;
+pub mod radix;
 use crate::client::{ClientError, ShardedClient};
 pub(crate) use backend::BackendV3;

View File

@@ -150,6 +150,14 @@ async fn main() -> Result<(), RouterError> {
 }
 }
+if let Some(max_batch_size) = max_batch_size {
+    if max_batch_size == 0 {
+        return Err(RouterError::ArgumentValidation(
+            "`max_batch_size` must be > 0".to_string(),
+        ));
+    }
+}
 let (backend, _backend_info) = connect_backend(
 max_input_tokens,
 max_total_tokens,

View File

@ -46,6 +46,7 @@ impl Queue {
pub(crate) fn new( pub(crate) fn new(
requires_padding: bool, requires_padding: bool,
block_size: u32, block_size: u32,
prefix_caching: bool,
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
@ -57,6 +58,7 @@ impl Queue {
tokio::spawn(queue_task( tokio::spawn(queue_task(
requires_padding, requires_padding,
block_size, block_size,
prefix_caching,
window_size, window_size,
speculate, speculate,
max_batch_total_tokens, max_batch_total_tokens,
@ -109,6 +111,7 @@ impl Queue {
async fn queue_task( async fn queue_task(
requires_padding: bool, requires_padding: bool,
block_size: u32, block_size: u32,
prefix_caching: bool,
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
@ -117,6 +120,7 @@ async fn queue_task(
let mut state = State::new( let mut state = State::new(
requires_padding, requires_padding,
block_size, block_size,
prefix_caching,
window_size, window_size,
speculate, speculate,
max_batch_total_tokens, max_batch_total_tokens,
@ -176,12 +180,19 @@ impl State {
fn new( fn new(
requires_padding: bool, requires_padding: bool,
block_size: u32, block_size: u32,
prefix_caching: bool,
window_size: Option<u32>, window_size: Option<u32>,
speculate: u32, speculate: u32,
max_batch_total_tokens: u32, max_batch_total_tokens: u32,
) -> Self { ) -> Self {
let block_allocator = (!requires_padding) let block_allocator = (!requires_padding).then(|| {
.then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size)); BlockAllocator::new(
max_batch_total_tokens,
block_size,
prefix_caching,
window_size,
)
});
Self { Self {
entries: VecDeque::with_capacity(128), entries: VecDeque::with_capacity(128),
@ -226,13 +237,20 @@ impl State {
} }
} }
if let Some(max_size) = max_size {
if max_size == 0 {
tracing::debug!("No capacity");
return None;
}
}
// Pad prefill_token_budget to be a multiple of block size // Pad prefill_token_budget to be a multiple of block size
let prefill_token_budget = let prefill_token_budget =
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size; ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
// Create span for this batch to add context to inference calls // Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty); let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
next_batch_span.follows_from(&Span::current()); next_batch_span.follows_from(Span::current());
let mut batch_requests = Vec::with_capacity(self.entries.len()); let mut batch_requests = Vec::with_capacity(self.entries.len());
let mut batch_entries = let mut batch_entries =
@ -298,7 +316,15 @@ impl State {
+ self.speculate + self.speculate
- 1; - 1;
match block_allocator.allocate(tokens).await { // If users wants the prefill logprobs, we cannot reuse the cache.
// So no input_ids for the radix tree.
let input_ids = if entry.request.decoder_input_details {
None
} else {
entry.request.input_ids.clone()
};
match block_allocator.allocate(tokens, input_ids).await {
None => { None => {
// Entry is over budget // Entry is over budget
// Add it back to the front // Add it back to the front
@ -324,11 +350,12 @@ impl State {
// Update entry // Update entry
entry.temp_span = Some(entry_batch_span); entry.temp_span = Some(entry_batch_span);
let (blocks, slots) = match &block_allocation { let (blocks, slots, prefix_len) = match &block_allocation {
None => (Vec::new(), Vec::new()), None => (Vec::new(), Vec::new(), 0),
Some(block_allocation) => ( Some(block_allocation) => (
block_allocation.blocks.clone(), block_allocation.blocks.clone(),
block_allocation.slots.clone(), block_allocation.slots.clone(),
block_allocation.prefix_len,
), ),
}; };
@ -356,6 +383,7 @@ impl State {
}), }),
inputs: entry.request.inputs.chunks_to_string(), inputs: entry.request.inputs.chunks_to_string(),
truncate: entry.request.truncate, truncate: entry.request.truncate,
add_special_tokens: entry.request.add_special_tokens,
parameters: Some(NextTokenChooserParameters::from( parameters: Some(NextTokenChooserParameters::from(
entry.request.parameters.clone(), entry.request.parameters.clone(),
)), )),
@ -365,6 +393,7 @@ impl State {
top_n_tokens: entry.request.top_n_tokens, top_n_tokens: entry.request.top_n_tokens,
blocks, blocks,
slots, slots,
prefix_len,
adapter_id: entry.request.adapter_id.clone(), adapter_id: entry.request.adapter_id.clone(),
}); });
// Set batch_time // Set batch_time
@ -473,6 +502,8 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::sync::Arc;
use super::*; use super::*;
use tracing::info_span; use tracing::info_span;
@ -485,7 +516,9 @@ mod tests {
let entry = Entry { let entry = Entry {
request: ValidGenerateRequest { request: ValidGenerateRequest {
inputs: vec![], inputs: vec![],
input_ids: Some(Arc::new(vec![])),
input_length: 0, input_length: 0,
add_special_tokens: true,
truncate: 0, truncate: 0,
decoder_input_details: false, decoder_input_details: false,
parameters: ValidParameters { parameters: ValidParameters {
@ -520,7 +553,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_append() { async fn test_append() {
let mut state = State::new(false, 1, None, 0, 16); let mut state = State::new(false, 1, false, None, 0, 16);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0); assert_eq!(state.next_id, 0);
@ -536,7 +569,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_empty() { async fn test_next_batch_empty() {
let mut state = State::new(false, 1, None, 0, 16); let mut state = State::new(false, 1, false, None, 0, 16);
assert!(state.next_batch(None, None, 1, 1).await.is_none()); assert!(state.next_batch(None, None, 1, 1).await.is_none());
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none()); assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
@ -544,7 +577,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_min_size() { async fn test_next_batch_min_size() {
let mut state = State::new(false, 1, None, 0, 16); let mut state = State::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -576,7 +609,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_max_size() { async fn test_next_batch_max_size() {
let mut state = State::new(false, 1, None, 0, 16); let mut state = State::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -596,7 +629,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_next_batch_token_budget() { async fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, None, 0, 2); let mut state = State::new(false, 1, false, None, 0, 2);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
state.append(entry1); state.append(entry1);
@ -629,14 +662,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_append() { async fn test_queue_append() {
let queue = Queue::new(false, 1, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry, _guard) = default_entry(); let (entry, _guard) = default_entry();
queue.append(entry); queue.append(entry);
} }
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_empty() { async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16);
assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(None, None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@ -644,7 +677,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_min_size() { async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -677,7 +710,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_max_size() { async fn test_queue_next_batch_max_size() {
let queue = Queue::new(false, 1, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -693,7 +726,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_budget() { async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -718,7 +751,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_token_speculate() { async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, None, 2, 16); let queue = Queue::new(false, 1, false, None, 2, 16);
let (entry1, _guard1) = default_entry(); let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry(); let (entry2, _guard2) = default_entry();
queue.append(entry1); queue.append(entry1);
@ -737,7 +770,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_queue_next_batch_dropped_receiver() { async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1, None, 0, 16); let queue = Queue::new(false, 1, false, None, 0, 16);
let (entry, _) = default_entry(); let (entry, _) = default_entry();
queue.append(entry); queue.append(entry);

831
backends/v3/src/radix.rs Normal file
View File

@ -0,0 +1,831 @@
use crate::block_allocator::{Allocator, BlockAllocation};
use slotmap::{DefaultKey, SlotMap};
use std::{
collections::{BTreeSet, HashMap},
sync::Arc,
};
pub struct RadixAllocator {
allocation_id: u64,
allocations: HashMap<u64, RadixAllocation>,
cache_blocks: RadixTrie,
/// Blocks that are immediately available for allocation.
free_blocks: Vec<u32>,
#[allow(dead_code)]
// This isn't used because the prefix needs to match without the windowing
// mechanism. This at worst overallocates; it is not incorrect.
window_size: Option<u32>,
block_size: u32,
}
impl RadixAllocator {
pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
RadixAllocator {
allocation_id: 0,
allocations: HashMap::new(),
cache_blocks: RadixTrie::new(block_size as usize),
// Block 0 is reserved for health checks.
free_blocks: (1..n_blocks).collect(),
window_size,
block_size,
}
}
fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option<Vec<u32>> {
if self.free_blocks.len() < n_blocks_needed {
// This is a bit annoying, we first extend the free list and then
// split it off again below. This is because we need to put it on
// the free list if we cannot allocate enough blocks. This is only
// temporary, the trie needs to be able to report whether it can
// allocate the requested amount. Just not implemented yet.
self.free_blocks.extend(
self.cache_blocks
.evict(n_blocks_needed - self.free_blocks.len()),
);
}
if self.free_blocks.len() >= n_blocks_needed {
Some(
self.free_blocks
.split_off(self.free_blocks.len() - n_blocks_needed),
)
} else {
None
}
}
}
// Allocator trait
impl Allocator for RadixAllocator {
fn allocate(
&mut self,
tokens: u32,
prefill_tokens: Option<Arc<Vec<u32>>>,
) -> Option<BlockAllocation> {
let mut blocks = vec![];
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
let node_id = self
.cache_blocks
.find(prefill_tokens.as_slice(), &mut blocks);
// Even if this allocation fails below, we need to increase the
// refcount to ensure that the prefix that was found is not evicted.
node_id
} else {
self.cache_blocks.root_id()
};
self.cache_blocks
.incref(prefix_node)
.expect("Failed to increment refcount");
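// `prefix_len` tokens are already covered by cached blocks; the remaining
// suffix is rounded up to whole blocks.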
let prefix_len = blocks.len() * self.block_size as usize;
let suffix_len = tokens - prefix_len as u32;
let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
match self.alloc_or_reclaim(suffix_blocks as usize) {
Some(suffix_blocks) => blocks.extend(suffix_blocks),
None => {
self.cache_blocks
.decref(prefix_node)
.expect("Failed to decrement refcount");
return None;
}
}
// 1:1 mapping of blocks and slots.
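// Otherwise, each block expands into `block_size` contiguous slot ids,
// stopping once there is one slot per token.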
let slots = if self.block_size == 1 {
blocks.clone()
} else {
let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize);
'slots: for block_id in &blocks {
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
slots.push(s);
if slots.len() as u32 == tokens {
break 'slots;
}
}
}
slots
};
let allocation = RadixAllocation {
prefix_node,
cached_prefix_len: prefix_len,
prefill_tokens: prefill_tokens.clone(),
};
tracing::debug!("Blocks {blocks:?}");
self.allocation_id += 1;
self.allocations.insert(self.allocation_id, allocation);
Some(BlockAllocation {
allocation_id: self.allocation_id,
block_allocator: None,
blocks,
slots,
prefix_len: prefix_len as u32,
})
}
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
let allocation = match self.allocations.remove(&allocation_id) {
Some(allocation) => allocation,
None => unreachable!("Tried to free an unknown allocation."),
};
self.cache_blocks
.decref(allocation.prefix_node)
.expect("Failed to decrement refcount");
if let Some(prefill_tokens) = allocation.prefill_tokens {
let prefill_tokens = prefill_tokens.as_slice();
// If there are prefill tokens that did not come from the cache,
// add them to the cache.
if prefill_tokens.len() > allocation.cached_prefix_len {
let aligned =
(prefill_tokens.len() / self.block_size as usize) * self.block_size as usize;
if aligned > 0 {
let prefix_len = self
.cache_blocks
.insert(
&prefill_tokens[..aligned],
&blocks[..aligned / self.block_size as usize],
)
// Unwrap, failing is a programming error.
.expect("Failed to store prefill tokens");
// We can have a prefill with the following structure:
//
// |---| From the prefix cache.
// A B C D E F G
//|--------| Found in the trie during insertion.
//
// This means that while processing this request there was a
// partially overlapping request that had A..=E in its
// prefill. In this case we need to free the blocks D E.
if prefix_len > allocation.cached_prefix_len {
self.free_blocks.extend(
&blocks[allocation.cached_prefix_len / self.block_size as usize
..prefix_len / self.block_size as usize],
);
}
}
}
// Free non-prefill blocks.
self.free_blocks
.extend(&blocks[prefill_tokens.len() / self.block_size as usize..]);
} else {
self.free_blocks.extend(blocks);
}
}
}
struct RadixAllocation {
prefix_node: NodeId,
cached_prefix_len: usize,
prefill_tokens: Option<Arc<Vec<u32>>>,
}
// Radix trie that is heavily inspired by radix attention from sglang.
//
// The trie is optimized for prefix caching:
//
// - A normal radix trie stores discrete values. In this radix trie,
// inserting *abc* with value *xyz* will also enable lookup for
// *a* (*x*) and *ab* (*xy*).
// - As a result, every value is required to have the same length as
// the key.
// - We store additional information in each node, such as last access
// time and a reference count.
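//
// For example, with a block size of 1, inserting the token sequence
// [1, 2, 3] with blocks [10, 11, 12] also makes the prefix [1, 2]
// resolvable to blocks [10, 11] via `find`.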
#[derive(Debug)]
pub enum TrieError {
InvalidNodeId,
RefCountUnderflow,
BlockTokenCountMismatch,
}
pub type NodeId = DefaultKey;
#[derive(Debug)]
pub struct RadixTrie {
/// Identifier of the root node.
root: DefaultKey,
/// Leaf node identifiers ordered by increasing recency.
leaves: BTreeSet<(u64, NodeId)>,
/// All trie nodes.
nodes: SlotMap<NodeId, TrieNode>,
/// Time as a monotonically increasing counter to avoid the system
/// call that a real-time lookup would require.
time: u64,
/// All blocks need to be aligned with this
block_size: usize,
}
impl RadixTrie {
/// Construct a new radix trie.
pub fn new(block_size: usize) -> Self {
let root = TrieNode::new(vec![], vec![], 0, None);
let mut nodes = SlotMap::new();
let root = nodes.insert(root);
RadixTrie {
leaves: BTreeSet::new(),
nodes,
root,
time: 0,
block_size,
}
}
/// Find the prefix of the given tokens.
///
/// The blocks corresponding to the part of the prefix that could be found
/// are written to `blocks`. The number of blocks is in `0..=tokens.len()`.
/// Returns the identifier of the trie node that contains the longest
/// prefix. The node identifier can be used by callers to e.g. increase its
/// reference count.
///
/// Using this method will update the access time of the traversed nodes.
pub fn find(&mut self, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
self.time += 1;
self.find_(self.root, key, blocks)
}
/// Find worker.
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
let node = &self.nodes[node_id];
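// Children are keyed by the first token of their key, so a single map
// lookup finds the only child that can share a prefix with `key`.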
if let Some(&child_id) = node.children.get(&key[0]) {
self.update_access_time(child_id);
let child = self.nodes.get(child_id).expect("Invalid child identifier");
let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
assert_eq!(shared_prefix_len % self.block_size, 0);
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
let key = &key[shared_prefix_len..];
if !key.is_empty() {
node_id = self.find_(child_id, key, blocks);
}
}
node_id
}
/// Decrease the reference count of a node.
pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
// We don't care about refcounting for root, since it will never
// be evicted.
if node_id == self.root {
return Ok(());
}
let node = self
.nodes
.get_mut(node_id)
.ok_or(TrieError::InvalidNodeId)?;
if node.ref_count == 0 {
return Err(TrieError::RefCountUnderflow);
}
node.ref_count -= 1;
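// Once unreferenced, the node becomes an eviction candidate; track it
// in the recency-ordered leaf set.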
if node.ref_count == 0 {
self.leaves.insert((node.last_accessed, node_id));
}
Ok(())
}
/// Increase the reference count of a node.
pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
if node_id == self.root {
return Ok(());
}
let node = self
.nodes
.get_mut(node_id)
.ok_or(TrieError::InvalidNodeId)?;
if node.ref_count == 0 {
self.leaves.remove(&(node.last_accessed, node_id));
}
node.ref_count += 1;
Ok(())
}
/// Evict `n_blocks` from the trie.
///
/// Returns the evicted blocks. When the length is less than `n_blocks`,
/// not enough blocks could be evicted.
pub fn evict(&mut self, n_blocks: usize) -> Vec<u32> {
// NOTE: we don't return Result here. If any of the unwrapping fails,
// it's a programming error in the trie implementation, not a user
// error caused by e.g. an invalid argument.
// TODO: add some bookkeeping in the future to check whether we can
// evict n_blocks and return `None` if we can't. We are now needlessly
// evicting prefixes from the cache in such a case.
let mut evicted = Vec::new();
while let Some((last_access, node_id)) = self.leaves.pop_first() {
let blocks_needed = n_blocks - evicted.len();
let node = self.nodes.get(node_id).expect("Leaf does not exist");
if blocks_needed >= node.blocks.len() {
// We need to evict the whole node if we need more blocks than it has.
let node = self.remove_node(node_id);
evicted.extend(node.blocks);
if evicted.len() >= n_blocks {
break;
}
} else {
// The node has more blocks than needed, so we'll just remove
// the required number of blocks and leave the remaining blocks
// untouched.
let node = self.nodes.get_mut(node_id).expect("Leaf does not exist");
node.key.truncate(node.blocks.len() - blocks_needed);
evicted.extend(node.blocks.split_off(node.blocks.len() - blocks_needed));
self.leaves.insert((last_access, node_id));
break;
}
}
evicted
}
/// Insert a prefill along with its blocks.
///
/// This method returns the length of the prefix that was already
/// in the trie. E.g. if the length is 10, this means that for
/// the first 10 elements of the prefix **the blocks are not updated**.
pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> {
self.time += 1;
let common = self.insert_(self.root, tokens, blocks)?;
Ok(common)
}
/// Insertion worker.
fn insert_(
&mut self,
node_id: NodeId,
tokens: &[u32],
blocks: &[u32],
) -> Result<usize, TrieError> {
// TODO: in the future we may want to check that the blocks match for
// the part of the prefix that is already in the trie to detect
// mismatches.
if tokens.len() != blocks.len() * self.block_size {
return Err(TrieError::BlockTokenCountMismatch);
}
if let Some(&child_id) = self.nodes[node_id].children.get(&tokens[0]) {
self.update_access_time(child_id);
let child = self
.nodes
.get_mut(child_id)
// Unwrap here, since failure is a bug.
.expect("Child node does not exist");
let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size);
// We are done, the prefix is already in the trie.
if shared_prefix_len == tokens.len() || shared_prefix_len == 0 {
return Ok(shared_prefix_len);
}
// The node's prefix is a prefix of the insertion prefix.
if shared_prefix_len == child.key.len() {
return Ok(shared_prefix_len
+ self.insert_(
child_id,
&tokens[shared_prefix_len..],
&blocks[shared_prefix_len / self.block_size..],
)?);
}
// The node's prefix and the insertion prefix only match partially,
// so split the node to just contain the matching part. Then insert the
// remainder of the prefix into the node again.
let child_id = self.split_node(child_id, shared_prefix_len);
let key = &tokens[shared_prefix_len..];
let blocks = &blocks[shared_prefix_len / self.block_size..];
Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?)
} else {
self.add_node(node_id, tokens, blocks);
Ok(0)
}
}
fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId {
// We have to make the current node a child to ensure that its
// properties and node id stay the same.
// This function unwraps; an invalid node_id is a programming error.
let node = self
.nodes
.get_mut(node_id)
.expect("Node to-be split does not exist");
let mut parent_key = node.key.split_off(prefix_len);
let mut parent_blocks = node.blocks.split_off(prefix_len);
// Move first part of the prefix to the parent. We swap to avoid
// an allocation + copy for both splits of the key/blocks.
std::mem::swap(&mut node.key, &mut parent_key);
std::mem::swap(&mut node.blocks, &mut parent_blocks);
let node_key = node.key[0];
let grandparent_id = node.parent.expect("Node does not have a parent");
let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
self.add_node_to_parent(parent_id, node_key, node_id);
// Reborrow to make the borrow checker happy.
let node = self
.nodes
.get_mut(node_id)
.expect("Node to-be split does not exist");
node.parent = Some(parent_id);
parent_id
}
/// Create a node and add it to the parent.
fn add_node(
&mut self,
parent_id: NodeId,
key: impl Into<Vec<u32>>,
blocks: impl Into<Vec<u32>>,
) -> NodeId {
let key = key.into();
let blocks = blocks.into();
let first = key[0];
let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
let child_id = self.nodes.insert(child);
self.add_node_to_parent(parent_id, first, child_id);
self.leaves.insert((self.time, child_id));
child_id
}
/// Add a node to the parent.
fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) {
// Unwrap here, passing in an unknown id is a programming error.
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
if parent.children.insert(first, child_id).is_none() {
// Only increase reference count if child does not replace another child.
self.incref(parent_id)
.expect("Failed to increase parent refcount");
}
}
/// Remove a node from the trie.
fn remove_node(&mut self, node_id: NodeId) -> TrieNode {
// Unwrap here, passing in an unknown id is a programming error.
let node = self.nodes.remove(node_id).expect("Unknown node");
let parent_id = node.parent.expect("Attempted to remove root node");
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
parent.children.remove(&node.key[0]);
self.decref(parent_id)
.expect("Failed to decrease parent refcount");
self.nodes.remove(node_id);
node
}
fn update_access_time(&mut self, node_id: NodeId) {
// Unwrap here, passing in an unknown id is a programming error.
let node = self.nodes.get_mut(node_id).expect("Unknown node");
// Update the ordered leaves set if the node is a leaf.
if self.leaves.remove(&(node.last_accessed, node_id)) {
self.leaves.insert((self.time, node_id));
}
node.last_accessed = self.time;
}
#[allow(dead_code)]
#[doc(hidden)]
/// Print debugging output for the trie.
///
/// In contrast to `Debug`, nicely formatted.
pub fn print_debug(&self) {
self.print_debug_(self.root, 0);
}
fn print_debug_(&self, node_id: NodeId, indent: usize) {
let node = &self.nodes[node_id];
eprintln!(
"{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}",
" ".repeat(indent),
node_id,
node.key,
node.blocks,
node.ref_count,
node.last_accessed,
node.parent,
node.children
);
for child_id in self.nodes[node_id].children.values() {
self.print_debug_(*child_id, indent + 2);
}
}
pub(crate) fn root_id(&self) -> DefaultKey {
self.root
}
}
/// Trie node.
#[derive(Debug)]
struct TrieNode {
blocks: Vec<u32>,
children: HashMap<u32, NodeId>,
key: Vec<u32>,
last_accessed: u64,
parent: Option<NodeId>,
ref_count: usize,
}
impl TrieNode {
fn new(key: Vec<u32>, blocks: Vec<u32>, last_accessed: u64, parent: Option<NodeId>) -> Self {
TrieNode {
children: HashMap::new(),
key,
blocks,
last_accessed,
parent,
ref_count: 0,
}
}
}
fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize {
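// Count the number of leading tokens that match, then round down to a
// multiple of `block_size` (e.g. a 3-token match with block size 2 yields 2).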
let full = left.iter().zip(right).take_while(|(a, b)| a == b).count();
(full / block_size) * block_size
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
#[test]
fn allocator_block_size() {
let mut cache = RadixAllocator::new(2, 12, None);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
assert_eq!(allocation.prefix_len, 0);
cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
assert_eq!(allocation.prefix_len, 4);
}
#[test]
fn allocator_block_size_non_aligned() {
let mut cache = RadixAllocator::new(2, 12, None);
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
assert_eq!(allocation.prefix_len, 0);
cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
assert_eq!(allocation.prefix_len, 2);
}
#[test]
fn allocator_reuses_prefixes() {
let mut cache = RadixAllocator::new(1, 12, None);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation.blocks, allocation.slots);
assert_eq!(allocation.prefix_len, 0);
cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation.prefix_len, 4);
}
#[test]
fn allocator_collects_older_prefixes_first() {
let mut cache = RadixAllocator::new(1, 7, None);
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);
assert_eq!(allocation1.prefix_len, 0);
let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap();
assert_eq!(allocation2.blocks, vec![1, 2]);
assert_eq!(allocation2.prefix_len, 0);
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
// We should get the blocks of the first allocation, since they are more recent.
let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap();
assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]);
assert_eq!(allocation3.prefix_len, 0);
}
#[test]
fn allocator_frees_fully_overlapping_prefills() {
let mut cache = RadixAllocator::new(1, 10, None);
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation3.prefix_len, 4);
// 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks.
assert_eq!(cache.free_blocks.len(), 5);
}
#[test]
fn allocator_frees_partially_overlapping_prefills() {
let mut cache = RadixAllocator::new(1, 20, None);
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap();
assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);
assert_eq!(allocation1.prefix_len, 0);
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
let allocation2 = cache
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
.unwrap();
assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]);
assert_eq!(allocation2.prefix_len, 2);
let allocation3 = cache
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
.unwrap();
assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]);
assert_eq!(allocation3.prefix_len, 2);
cache.free(allocation3.blocks.clone(), allocation3.allocation_id);
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
// 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2.
assert_eq!(cache.free_blocks.len(), 11);
let allocation4 = cache
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
.unwrap();
assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]);
assert_eq!(allocation4.prefix_len, 6);
assert_eq!(cache.free_blocks.len(), 11);
let allocation5 = cache
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
.unwrap();
assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]);
assert_eq!(allocation5.prefix_len, 6);
assert_eq!(cache.free_blocks.len(), 11);
}
#[test]
fn trie_insertions_have_correct_prefix_len() {
let mut trie = RadixTrie::new(1);
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);
// Already exists.
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3);
// Completely new at root-level
assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0);
// Contains full prefix, but longer.
assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3);
// Shares partial prefix, we need a split.
assert_eq!(
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
.unwrap(),
4
);
}
#[test]
fn trie_insertions_block_size() {
let mut trie = RadixTrie::new(2);
assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 0);
// Already exists.
// But needs to be block_size aligned
assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 4);
// Completely new at root-level
assert_eq!(trie.insert(&[1, 2, 3, 4], &[1, 2]).unwrap(), 0);
// Contains full prefix, but longer.
assert_eq!(trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4);
// Shares partial prefix, we need a split.
assert_eq!(
trie.insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3])
.unwrap(),
2
);
}
#[test]
fn trie_get_returns_correct_blocks() {
let mut trie = RadixTrie::new(1);
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
.unwrap();
let mut blocks = Vec::new();
trie.find(&[0], &mut blocks);
assert_eq!(blocks, vec![0]);
blocks.clear();
trie.find(&[0, 1, 2], &mut blocks);
assert_eq!(blocks, vec![0, 1, 2]);
blocks.clear();
trie.find(&[1, 2, 3], &mut blocks);
assert_eq!(blocks, vec![1, 2, 3]);
blocks.clear();
trie.find(&[0, 1, 2, 3], &mut blocks);
assert_eq!(blocks, vec![0, 1, 2, 3]);
blocks.clear();
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
assert_eq!(blocks, vec![0, 1, 2, 3, 4]);
blocks.clear();
trie.find(&[0, 1, 2, 3, 5], &mut blocks);
assert_eq!(blocks, vec![0, 1, 2, 3, 5]);
}
#[test]
fn trie_evict_removes_correct_blocks() {
let mut trie = RadixTrie::new(1);
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
.unwrap();
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
let mut blocks = Vec::new();
// Remove fewer than the leaf's blocks.
assert_eq!(trie.evict(1), vec![7]);
trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]);
// Refresh other leaf.
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
trie.find(&[1, 2, 3], &mut blocks);
// Remove exactly the leaf's blocks.
assert_eq!(trie.evict(2), vec![5, 6]);
blocks.clear();
trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
assert_eq!(blocks, vec![0, 1, 2, 3]);
trie.find(&[1, 2, 3], &mut blocks);
// Remove more than the leaf's blocks.
assert_eq!(trie.evict(3), vec![4, 3, 2]);
blocks.clear();
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
assert_eq!(blocks, vec![0, 1]);
// Clear out the whole trie.
assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
}
}

View File

@ -148,6 +148,7 @@ async fn prefill(
}), }),
inputs: sequence.clone(), inputs: sequence.clone(),
truncate: sequence_length, truncate: sequence_length,
add_special_tokens: true,
parameters: Some(parameters.clone()), parameters: Some(parameters.clone()),
stopping_parameters: Some(StoppingCriteriaParameters { stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: decode_length, max_new_tokens: decode_length,
@ -157,6 +158,7 @@ async fn prefill(
top_n_tokens: top_n_tokens.unwrap_or(0), top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![], blocks: vec![],
slots: vec![], slots: vec![],
prefix_len: 0,
adapter_id: None, adapter_id: None,
}) })
.collect(); .collect();

View File

@ -757,7 +757,12 @@ class AsyncClient:
continue continue
payload = byte_payload.decode("utf-8") payload = byte_payload.decode("utf-8")
if payload.startswith("data:"): if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) payload_data = (
payload.lstrip("data:").rstrip("\n").removeprefix(" ")
)
if payload_data == "[DONE]":
break
json_payload = json.loads(payload_data)
try: try:
response = ChatCompletionChunk(**json_payload) response = ChatCompletionChunk(**json_payload)
yield response yield response

View File

@ -556,6 +556,37 @@
} }
} }
} }
},
"/v1/models": {
"get": {
"tags": [
"Text Generation Inference"
],
"summary": "Get model info",
"operationId": "openai_get_model_info",
"responses": {
"200": {
"description": "Served model info",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ModelInfo"
}
}
}
},
"404": {
"description": "Model not found",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
}
}
}
}
}
}
} }
}, },
"components": { "components": {
@ -819,6 +850,13 @@
"example": "1.0", "example": "1.0",
"nullable": true "nullable": true
}, },
"guideline": {
"type": "string",
"description": "A guideline to be used in the chat_template",
"default": "null",
"example": "null",
"nullable": true
},
"logit_bias": { "logit_bias": {
"type": "array", "type": "array",
"items": { "items": {
@ -917,7 +955,7 @@
"tool_prompt": { "tool_prompt": {
"type": "string", "type": "string",
"description": "A prompt to be appended before the tools", "description": "A prompt to be appended before the tools",
"example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"", "example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.",
"nullable": true "nullable": true
}, },
"tools": { "tools": {
@ -1740,6 +1778,35 @@
} }
] ]
}, },
"ModelInfo": {
"type": "object",
"required": [
"id",
"object",
"created",
"owned_by"
],
"properties": {
"created": {
"type": "integer",
"format": "int64",
"example": 1686935002,
"minimum": 0
},
"id": {
"type": "string",
"example": "gpt2"
},
"object": {
"type": "string",
"example": "model"
},
"owned_by": {
"type": "string",
"example": "openai"
}
}
},
"OutputMessage": { "OutputMessage": {
"oneOf": [ "oneOf": [
{ {
@ -1817,7 +1884,8 @@
"type": "object", "type": "object",
"required": [ "required": [
"finish_reason", "finish_reason",
"generated_tokens" "generated_tokens",
"input_length"
], ],
"properties": { "properties": {
"finish_reason": { "finish_reason": {
@ -1829,6 +1897,12 @@
"example": 1, "example": 1,
"minimum": 0 "minimum": 0
}, },
"input_length": {
"type": "integer",
"format": "int32",
"example": 1,
"minimum": 0
},
"seed": { "seed": {
"type": "integer", "type": "integer",
"format": "int64", "format": "int64",

View File

@ -17,8 +17,6 @@
title: Installation from source title: Installation from source
- local: supported_models - local: supported_models
title: Supported Models and Hardware title: Supported Models and Hardware
- local: messages_api
title: Messages API
- local: architecture - local: architecture
title: Internal Architecture title: Internal Architecture
- local: usage_statistics - local: usage_statistics
@ -33,8 +31,6 @@
title: Serving Private & Gated Models title: Serving Private & Gated Models
- local: basic_tutorials/using_cli - local: basic_tutorials/using_cli
title: Using TGI CLI title: Using TGI CLI
- local: basic_tutorials/launcher
title: All TGI CLI options
- local: basic_tutorials/non_core_models - local: basic_tutorials/non_core_models
title: Non-core Model Serving title: Non-core Model Serving
- local: basic_tutorials/safety - local: basic_tutorials/safety
@ -48,6 +44,14 @@
- local: basic_tutorials/train_medusa - local: basic_tutorials/train_medusa
title: Train Medusa title: Train Medusa
title: Tutorials title: Tutorials
- sections:
- local: reference/launcher
title: All TGI CLI options
- local: reference/metrics
title: Exported Metrics
- local: reference/api_reference
title: API Reference
title: Reference
- sections: - sections:
- local: conceptual/streaming - local: conceptual/streaming
title: Streaming title: Streaming
@ -64,7 +68,7 @@
- local: conceptual/speculation - local: conceptual/speculation
title: Speculation (Medusa, ngram) title: Speculation (Medusa, ngram)
- local: conceptual/guidance - local: conceptual/guidance
title: How Guidance Works (via outlines title: How Guidance Works (via outlines)
- local: conceptual/lora - local: conceptual/lora
title: LoRA (Low-Rank Adaptation) title: LoRA (Low-Rank Adaptation)

View File

@ -1,81 +1,125 @@
# Consuming Text Generation Inference # Consuming Text Generation Inference
There are many ways you can consume Text Generation Inference server in your applications. After launching, you can use the `/generate` route and make a `POST` request to get results from the server. You can also use the `/generate_stream` route if you want TGI to return a stream of tokens. You can make the requests using the tool of your preference, such as curl, Python or TypeScrpt. For a final end-to-end experience, we also open-sourced ChatUI, a chat interface for open-source models. There are many ways to consume Text Generation Inference (TGI) server in your applications. After launching the server, you can use the [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) `/v1/chat/completions` route and make a `POST` request to get results from the server. You can also pass `"stream": true` to the call if you want TGI to return a stream of tokens.
For more information on the API, consult the OpenAPI documentation of `text-generation-inference` available [here](https://huggingface.github.io/text-generation-inference).
You can make the requests using any tool of your preference, such as curl, Python, or TypeScript. For an end-to-end experience, we've open-sourced [ChatUI](https://github.com/huggingface/chat-ui), a chat interface for open-access models.
## curl ## curl
After the launch, you can query the model using either the `/generate` or `/generate_stream` routes: After a successful server launch, you can query the model using the `v1/chat/completions` route, to get responses that are compliant to the OpenAI Chat Completion spec:
```bash
curl localhost:8080/v1/chat/completions \
-X POST \
-d '{
"model": "tgi",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "What is deep learning?"
}
],
"stream": true,
"max_tokens": 20
}' \
-H 'Content-Type: application/json'
```
For non-chat use-cases, you can also use the `/generate` and `/generate_stream` routes.
```bash ```bash
curl 127.0.0.1:8080/generate \ curl 127.0.0.1:8080/generate \
-X POST \ -X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ -d '{
"inputs":"What is Deep Learning?",
"parameters":{
"max_new_tokens":20
}
}' \
-H 'Content-Type: application/json' -H 'Content-Type: application/json'
``` ```
## Python
## Inference Client ### Inference Client
[`huggingface-hub`](https://huggingface.co/docs/huggingface_hub/main/en/index) is a Python library to interact with the Hugging Face Hub, including its endpoints. It provides a nice high-level class, [`~huggingface_hub.InferenceClient`], which makes it easy to make calls to a TGI endpoint. `InferenceClient` also takes care of parameter validation and provides a simple to-use interface. [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/main/en/index) is a Python library to interact with the Hugging Face Hub, including its endpoints. It provides a high-level class, [`huggingface_hub.InferenceClient`](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient), which makes it easy to make calls to TGI's Messages API. `InferenceClient` also takes care of parameter validation and provides a simple-to-use interface.
You can simply install `huggingface-hub` package with pip.
Install the `huggingface_hub` package via pip.
```bash ```bash
pip install huggingface-hub pip install huggingface_hub
``` ```
Once you start the TGI server, instantiate `InferenceClient()` with the URL to the endpoint serving the model. You can then call `text_generation()` to hit the endpoint through Python. You can now use `InferenceClient` the exact same way you would use the `OpenAI` client in Python.
```python ```python
from huggingface_hub import InferenceClient from huggingface_hub import InferenceClient
client = InferenceClient(model="http://127.0.0.1:8080") client = InferenceClient(
client.text_generation(prompt="Write a code for snake game") base_url="http://localhost:8080/v1/",
)
output = client.chat.completions.create(
model="tgi",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Count to 10"},
],
stream=True,
max_tokens=1024,
)
for chunk in output:
print(chunk.choices[0].delta.content)
``` ```
You can do streaming with `InferenceClient` by passing `stream=True`. Streaming will return tokens as they are being generated in the server. To use streaming, you can do as follows: You can check out more details about OpenAI compatibility [here](https://huggingface.co/docs/huggingface_hub/en/guides/inference#openai-compatibility).
There is also an async version of the client, `AsyncInferenceClient`, based on `asyncio` and `aiohttp`. You can find docs for it [here](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.AsyncInferenceClient).
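As a quick illustration, here is a minimal sketch of using the async client with the same Messages API (assuming a TGI server running locally on port 8080, as in the examples above):
```python
import asyncio

from huggingface_hub import AsyncInferenceClient

client = AsyncInferenceClient(base_url="http://localhost:8080/v1/")

async def main():
    # Same chat.completions call as the synchronous client, awaited instead.
    output = await client.chat.completions.create(
        model="tgi",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Count to 10"},
        ],
        max_tokens=1024,
    )
    print(output.choices[0].message.content)

asyncio.run(main())
```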
### OpenAI Client
You can directly use the OpenAI [Python](https://github.com/openai/openai-python) or [JS](https://github.com/openai/openai-node) clients to interact with TGI.
Install the OpenAI Python package via pip.
```bash
pip install openai
```
```python ```python
for token in client.text_generation("How do you make cheese?", max_new_tokens=12, stream=True): from openai import OpenAI
print(token)
# init the client but point it to TGI
client = OpenAI(
base_url="http://localhost:8080/v1/",
api_key="-"
)
chat_completion = client.chat.completions.create(
model="tgi",
messages=[
{"role": "system", "content": "You are a helpful assistant." },
{"role": "user", "content": "What is deep learning?"}
],
stream=True
)
# iterate and print stream
for message in chat_completion:
print(message)
``` ```
Another parameter you can use with TGI backend is `details`. You can get more details on generation (tokens, probabilities, etc.) by setting `details` to `True`. When it's specified, TGI will return a `TextGenerationResponse` or `TextGenerationStreamResponse` rather than a string or stream. ## UI
```python ### Gradio
output = client.text_generation(prompt="Meaning of life is", details=True)
print(output)
# TextGenerationResponse(generated_text=' a complex concept that is not always clear to the individual. It is a concept that is not always', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=20, seed=None, prefill=[], tokens=[Token(id=267, text=' a', logprob=-2.0723474, special=False), Token(id=11235, text=' complex', logprob=-3.1272552, special=False), Token(id=17908, text=' concept', logprob=-1.3632495, special=False),..))
```
You can see how to stream below.
```python
output = client.text_generation(prompt="Meaning of life is", stream=True, details=True)
print(next(iter(output)))
# TextGenerationStreamResponse(token=Token(id=267, text=' a', logprob=-2.0723474, special=False), generated_text=None, details=None)
```
You can check out the details of the function [here](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation). There is also an async version of the client, `AsyncInferenceClient`, based on `asyncio` and `aiohttp`. You can find docs for it [here](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.AsyncInferenceClient)
## ChatUI
ChatUI is an open-source interface built for LLM serving. It offers many customization options, such as web search with SERP API and more. ChatUI can automatically consume the TGI server and even provides an option to switch between different TGI endpoints. You can try it out at [Hugging Chat](https://huggingface.co/chat/), or use the [ChatUI Docker Space](https://huggingface.co/new-space?template=huggingchat/chat-ui-template) to deploy your own Hugging Chat to Spaces.
To serve both ChatUI and TGI in same environment, simply add your own endpoints to the `MODELS` variable in `.env.local` file inside the `chat-ui` repository. Provide the endpoints pointing to where TGI is served.
```
{
// rest of the model config here
"endpoints": [{"url": "https://HOST:PORT/generate_stream"}]
}
```
![ChatUI](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/chatui_screen.png)
## Gradio
Gradio is a Python library that helps you build web applications for your machine learning models with a few lines of code. It has a `ChatInterface` wrapper that helps create neat UIs for chatbots. Let's take a look at how to create a chatbot with streaming mode using TGI and Gradio. Let's install Gradio and Hub Python library first. Gradio is a Python library that helps you build web applications for your machine learning models with a few lines of code. It has a `ChatInterface` wrapper that helps create neat UIs for chatbots. Let's take a look at how to create a chatbot with streaming mode using TGI and Gradio. Let's install Gradio and Hub Python library first.
@ -89,19 +133,28 @@ Assume you are serving your model on port 8080, we will query through [Inference
import gradio as gr import gradio as gr
from huggingface_hub import InferenceClient from huggingface_hub import InferenceClient
client = InferenceClient(model="http://127.0.0.1:8080") client = InferenceClient(base_url="http://127.0.0.1:8080")
def inference(message, history): def inference(message, history):
partial_message = "" partial_message = ""
for token in client.text_generation(message, max_new_tokens=20, stream=True): output = client.chat.completions.create(
partial_message += token messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": message},
],
stream=True,
max_tokens=1024,
)
for chunk in output:
partial_message += chunk.choices[0].delta.content
yield partial_message yield partial_message
gr.ChatInterface( gr.ChatInterface(
inference, inference,
chatbot=gr.Chatbot(height=300), chatbot=gr.Chatbot(height=300),
textbox=gr.Textbox(placeholder="Chat with me!", container=False, scale=7), textbox=gr.Textbox(placeholder="Chat with me!", container=False, scale=7),
description="This is the demo for Gradio UI consuming TGI endpoint with LLaMA 7B-Chat model.", description="This is the demo for Gradio UI consuming TGI endpoint.",
title="Gradio 🤝 TGI", title="Gradio 🤝 TGI",
examples=["Are tomatoes vegetables?"], examples=["Are tomatoes vegetables?"],
retry_btn="Retry", retry_btn="Retry",
@ -110,20 +163,7 @@ gr.ChatInterface(
).queue().launch() ).queue().launch()
``` ```
The UI looks like this 👇 You can check out the UI and try the demo directly here 👇
<div class="flex justify-center">
<img
class="block dark:hidden"
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/gradio-tgi.png"
/>
<img
class="hidden dark:block"
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/gradio-tgi-dark.png"
/>
</div>
You can try the demo directly here 👇
<div class="block dark:hidden"> <div class="block dark:hidden">
<iframe <iframe
@ -141,15 +181,19 @@ You can try the demo directly here 👇
</div> </div>
You can disable streaming mode using `return` instead of `yield` in your inference function, like below.
```python
def inference(message, history):
return client.text_generation(message, max_new_tokens=20)
```
You can read more about how to customize a `ChatInterface` [here](https://www.gradio.app/guides/creating-a-chatbot-fast). You can read more about how to customize a `ChatInterface` [here](https://www.gradio.app/guides/creating-a-chatbot-fast).
## API documentation ### ChatUI
You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route. The Swagger UI is also available [here](https://huggingface.github.io/text-generation-inference). [ChatUI](https://github.com/huggingface/chat-ui) is an open-source interface built for consuming LLMs. It offers many customization options, such as web search with SERP API and more. ChatUI can automatically consume the TGI server and even provides an option to switch between different TGI endpoints. You can try it out at [Hugging Chat](https://huggingface.co/chat/), or use the [ChatUI Docker Space](https://huggingface.co/new-space?template=huggingchat/chat-ui-template) to deploy your own Hugging Chat to Spaces.
To serve both ChatUI and TGI in the same environment, simply add your own endpoints to the `MODELS` variable in the `.env.local` file inside the `chat-ui` repository. Provide the endpoints pointing to where TGI is served.
```
{
// rest of the model config here
"endpoints": [{"url": "https://HOST:PORT/generate_stream"}]
}
```
![ChatUI](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/chatui_screen.png)

View File

@ -4,7 +4,7 @@ Text Generation Inference improves the model in several aspects.
## Quantization ## Quantization
TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323) and [AWQ](https://arxiv.org/abs/2306.00978) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq` or `awq` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq) when using AWQ quantization, you need to point to one of the models [here](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to [quantization guide](./../conceptual/quantization) TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [Marlin](https://github.com/IST-DASLab/marlin), [EETQ](https://github.com/NetEase-FuXi/EETQ), [EXL2](https://github.com/turboderp/exllamav2), and [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq`, `awq`, `marlin`, `exl2`, `eetq` or `fp8` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq). Similarly, when using AWQ quantization, you need to point to one of [these models](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to [quantization guide](./../conceptual/quantization)
## RoPE Scaling ## RoPE Scaling

View File

@ -4,7 +4,7 @@ Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-
These features are available starting from version `1.4.3`. They are accessible via the [`huggingface_hub`](https://pypi.org/project/huggingface-hub/) library. The tool support is compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them! These features are available starting from version `1.4.3`. They are accessible via the [`huggingface_hub`](https://pypi.org/project/huggingface-hub/) library. The tool support is compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!
_note: guidance is supported as grammar in the `/generate` endpoint and as tools in the `/chat/completions` endpoint._ _note: guidance is supported as grammar in the `/generate` endpoint and as tools in the `v1/chat/completions` endpoint._
## How it works ## How it works
@ -157,7 +157,12 @@ from huggingface_hub import InferenceClient
client = InferenceClient("http://localhost:3000") client = InferenceClient("http://localhost:3000")
regexp = "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)" section_regex = "(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"
regexp = f"HELLO\.{section_regex}\.WORLD\.{section_regex}"
# This is a more realistic example of an ip address regex
# regexp = f"{section_regex}\.{section_regex}\.{section_regex}\.{section_regex}"
resp = client.text_generation( resp = client.text_generation(
f"Whats Googles DNS? Please use the following regex: {regexp}", f"Whats Googles DNS? Please use the following regex: {regexp}",
@ -170,7 +175,7 @@ resp = client.text_generation(
print(resp) print(resp)
# 7.1.1.1 # HELLO.255.WORLD.255
``` ```

View File

@ -84,7 +84,7 @@ print(chat)
``` ```
or with OpenAi's library: or with OpenAI's [client library](https://github.com/openai/openai-python):
```python ```python
from openai import OpenAI from openai import OpenAI

View File

@ -36,6 +36,18 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
``` ```
Additionally, you can specify the path to the LoRA models using the `LORA_ADAPTERS_PATH` environment variable. For example:
```bash
LORA_ADAPTERS=myadapter=/some/path/to/adapter,myadapter2=/another/path/to/adapter
```
Note that it's possible to mix adapter_ids with adapter_id=adapter_path, e.g.
```bash
LORA_ADAPTERS=predibase/dbpedia,myadapter=/path/to/dir/
```
In the server logs, you will see the following message: In the server logs, you will see the following message:
```txt ```txt

View File

@ -1,6 +1,40 @@
# Quantization # Quantization
TGI offers GPTQ and bits-and-bytes quantization to quantize large language models. TGI offers many quantization schemes to run LLMs effectively and fast based on your use-case. TGI supports GPTQ, AWQ, bits-and-bytes, EETQ, Marlin, EXL2 and fp8 quantization.
To leverage GPTQ, AWQ, Marlin and EXL2 quants, you must provide pre-quantized weights, whereas for bits-and-bytes, EETQ and fp8, weights are quantized by TGI on the fly.
We recommend using the official quantization scripts for creating your quants:
1. [AWQ](https://github.com/casper-hansen/AutoAWQ/blob/main/examples/quantize.py)
2. [GPTQ/Marlin](https://github.com/AutoGPTQ/AutoGPTQ/blob/main/examples/quantization/basic_usage.py)
3. [EXL2](https://github.com/turboderp/exllamav2/blob/master/doc/convert.md)
For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest.
## Quantization with bitsandbytes, EETQ & fp8
bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models. Unlike GPTQ quantization, bitsandbytes doesn't require a calibration dataset or any post-processing: weights are automatically quantized on load. However, inference with bitsandbytes is slower than GPTQ or FP16 precision.
8-bit quantization enables multi-billion parameter scale models to fit in smaller hardware without degrading performance too much.
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇
```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes
```
4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇
```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes-nf4
```
You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
Similarly, you can pass `--quantize eetq` or `--quantize fp8` for the respective quantization schemes.
In addition to this, TGI allows creating GPTQ quants directly by passing the model weights and a calibration dataset.
## Quantization with GPTQ ## Quantization with GPTQ
@ -36,24 +70,3 @@ You can learn more about the quantization options by running `text-generation-se
If you wish to do more with GPTQ models (e.g. train an adapter on top), you can read about transformers GPTQ integration [here](https://huggingface.co/blog/gptq-integration). If you wish to do more with GPTQ models (e.g. train an adapter on top), you can read about transformers GPTQ integration [here](https://huggingface.co/blog/gptq-integration).
You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf). You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf).
## Quantization with bitsandbytes
bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models. Unlike GPTQ quantization, bitsandbytes doesn't require a calibration dataset or any post-processing weights are automatically quantized on load. However, inference with bitsandbytes is slower than GPTQ or FP16 precision.
8-bit quantization enables multi-billion parameter scale models to fit in smaller hardware without degrading performance too much.
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇
```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes
```
4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇
```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes-nf4
```
You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).

View File

@ -48,34 +48,29 @@ To stream tokens with `InferenceClient`, simply pass `stream=True` and iterate o
```python ```python
from huggingface_hub import InferenceClient from huggingface_hub import InferenceClient
client = InferenceClient("http://127.0.0.1:8080") client = InferenceClient(base_url="http://127.0.0.1:8080")
for token in client.text_generation("How do you make cheese?", max_new_tokens=12, stream=True): output = client.chat.completions.create(
print(token) messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Count to 10"},
],
stream=True,
max_tokens=1024,
)
# To for chunk in output:
# make print(chunk.choices[0].delta.content)
# cheese
#,
# you
# need
# to
# start
# with
# milk
#.
```
If you want additional details, you can add `details=True`. In this case, you get a `TextGenerationStreamResponse` which contains additional information such as the probabilities and the tokens. For the final response in the stream, it also returns the full generated text. # 1
# 2
```python # 3
for details in client.text_generation("How do you make cheese?", max_new_tokens=12, details=True, stream=True): # 4
print(details) # 5
# 6
#TextGenerationStreamResponse(token=Token(id=193, text='\n', logprob=-0.007358551, special=False), generated_text=None, details=None) # 7
#TextGenerationStreamResponse(token=Token(id=2044, text='To', logprob=-1.1357422, special=False), generated_text=None, details=None) # 8
#TextGenerationStreamResponse(token=Token(id=717, text=' make', logprob=-0.009841919, special=False), generated_text=None, details=None) # 9
#... # 10
#TextGenerationStreamResponse(token=Token(id=25, text='.', logprob=-1.3408203, special=False), generated_text='\nTo make cheese, you need to start with milk.', details=StreamDetails(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=12, seed=None))
``` ```
The `huggingface_hub` library also comes with an `AsyncInferenceClient` in case you need to handle the requests concurrently. The `huggingface_hub` library also comes with an `AsyncInferenceClient` in case you need to handle the requests concurrently.
@ -83,31 +78,46 @@ The `huggingface_hub` library also comes with an `AsyncInferenceClient` in case
```python ```python
from huggingface_hub import AsyncInferenceClient from huggingface_hub import AsyncInferenceClient
client = AsyncInferenceClient("http://127.0.0.1:8080") client = AsyncInferenceClient(base_url="http://127.0.0.1:8080")
async for token in await client.text_generation("How do you make cheese?", stream=True): async def main():
print(token) stream = await client.chat.completions.create(
messages=[{"role": "user", "content": "Say this is a test"}],
stream=True,
)
async for chunk in stream:
print(chunk.choices[0].delta.content or "", end="")
# To asyncio.run(main())
# make
# cheese # This
#, # is
# you # a
# need # test
# to
# start
# with
# milk
#. #.
``` ```
### Streaming with cURL ### Streaming with cURL
To use the `generate_stream` endpoint with curl, you can add the `-N` flag, which disables curl default buffering and shows data as it arrives from the server To use the OpenAI Chat Completions compatible Messages API `v1/chat/completions` endpoint with curl, you can add the `-N` flag, which disables curl default buffering and shows data as it arrives from the server
```curl ```curl
curl -N 127.0.0.1:8080/generate_stream \ curl localhost:8080/v1/chat/completions \
-X POST \ -X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ -d '{
"model": "tgi",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "What is deep learning?"
}
],
"stream": true,
"max_tokens": 20
}' \
-H 'Content-Type: application/json' -H 'Content-Type: application/json'
``` ```

View File

@ -12,7 +12,24 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \ docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \ --device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \ --ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.2.0-intel \ ghcr.io/huggingface/text-generation-inference:2.2.0-intel-xpu \
--model-id $model --cuda-graphs 0
```
# Using TGI with Intel CPUs
Intel® Extension for PyTorch (IPEX) also provides further optimizations for Intel CPUs. IPEX provides optimized operations such as flash attention, paged attention, Add + LayerNorm, ROPE, and more.
On a server powered by an Intel CPU, TGI can be launched with the following command:
```bash
model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.2.0-intel-cpu \
--model-id $model --cuda-graphs 0 --model-id $model --cuda-graphs 0
``` ```
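Once the container is up, the CPU-backed server is queried like any other TGI endpoint. A minimal sketch, assuming the container runs with `--net host` on TGI's default port 80 as above:
```python
# Minimal sketch: query the CPU-backed TGI server launched above.
# Assumes --net host and TGI's default port 80; adjust base_url for your setup.
from huggingface_hub import InferenceClient

client = InferenceClient(base_url="http://localhost:80")
response = client.chat.completions.create(
    messages=[{"role": "user", "content": "What is Deep Learning?"}],
    max_tokens=64,
)
print(response.choices[0].message.content)
```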

View File

@ -21,7 +21,7 @@ TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPU
## Consuming TGI ## Consuming TGI
Once TGI is running, you can use the `generate` endpoint by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint. Once TGI is running, you can use the `generate` endpoint or the Open AI Chat Completion API compatible [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint.
<inferencesnippet> <inferencesnippet>
<python> <python>

View File

@ -1,18 +1,31 @@
# Messages API # HTTP API Reference
#### Table of Contents
- [Text Generation Inference custom API](#text-generation-inference-custom-api)
- [OpenAI Messages API](#openai-messages-api)
- [Making a Request](#making-a-request)
- [Streaming](#streaming)
- [Synchronous](#synchronous)
- [Hugging Face Inference Endpoints](#hugging-face-inference-endpoints)
- [Cloud Providers](#cloud-providers)
- [Amazon SageMaker](#amazon-sagemaker)
The HTTP API is a RESTful API that allows you to interact with the text-generation-inference component. Two endpoints are available:
* Text Generation Inference [custom API](https://huggingface.github.io/text-generation-inference/)
* OpenAI's [Messages API](#openai-messages-api)
## Text Generation Inference custom API
Check the [API documentation](https://huggingface.github.io/text-generation-inference/) for more information on how to interact with the Text Generation Inference API.
## OpenAI Messages API
Text Generation Inference (TGI) now supports the Messages API, which is fully compatible with the OpenAI Chat Completion API. This feature is available starting from version 1.4.0. You can use OpenAI's client libraries or third-party libraries expecting OpenAI schema to interact with TGI's Messages API. Below are some examples of how to utilize this compatibility. Text Generation Inference (TGI) now supports the Messages API, which is fully compatible with the OpenAI Chat Completion API. This feature is available starting from version 1.4.0. You can use OpenAI's client libraries or third-party libraries expecting OpenAI schema to interact with TGI's Messages API. Below are some examples of how to utilize this compatibility.
> **Note:** The Messages API is supported from TGI version 1.4.0 and above. Ensure you are using a compatible version to access this feature. > **Note:** The Messages API is supported from TGI version 1.4.0 and above. Ensure you are using a compatible version to access this feature.
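For example, the official `openai` Python client can be pointed at a local TGI instance. The base URL and placeholder API key below are assumptions for a default local deployment:
```python
# Minimal sketch: the openai Python client talking to a local TGI instance.
# base_url is an assumption for a local deployment; api_key is a placeholder
# for a local, unauthenticated instance.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

chat = client.chat.completions.create(
    model="tgi",  # the value used in TGI's own examples; the server hosts a single model
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is deep learning?"},
    ],
    max_tokens=64,
)
print(chat.choices[0].message.content)
```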
#### Table of Contents
- [Making a Request](#making-a-request)
- [Streaming](#streaming)
- [Synchronous](#synchronous)
- [Hugging Face Inference Endpoints](#hugging-face-inference-endpoints)
- [Cloud Providers](#cloud-providers)
- [Amazon SageMaker](#amazon-sagemaker)
## Making a Request ## Making a Request
You can make a request to TGI's Messages API using `curl`. Here's an example: You can make a request to TGI's Messages API using `curl`. Here's an example:

View File

@ -0,0 +1,30 @@
# Metrics
TGI exposes multiple metrics that can be collected via the `/metrics` Prometheus endpoint.
These metrics can be used to monitor the performance of TGI, autoscale deployments, and help identify bottlenecks.
The following metrics are exposed:
| Metric Name | Description | Type | Unit |
|--------------------------------------------|------------------------------------------------------------------------------------------|-----------|---------|
| `tgi_batch_current_max_tokens` | Maximum tokens for the current batch | Gauge | Count |
| `tgi_batch_current_size` | Current batch size | Gauge | Count |
| `tgi_batch_decode_duration` | Time spent decoding a batch per method (prefill or decode) | Histogram | Seconds |
| `tgi_batch_filter_duration` | Time spent filtering batches and sending generated tokens per method (prefill or decode) | Histogram | Seconds |
| `tgi_batch_forward_duration` | Batch forward duration per method (prefill or decode) | Histogram | Seconds |
| `tgi_batch_inference_count` | Inference calls per method (prefill or decode) | Counter | Count |
| `tgi_batch_inference_duration` | Batch inference duration | Histogram | Seconds |
| `tgi_batch_inference_success` | Number of successful inference calls per method (prefill or decode) | Counter | Count |
| `tgi_batch_next_size` | Batch size of the next batch | Histogram | Count |
| `tgi_queue_size` | Current queue size | Gauge | Count |
| `tgi_request_count` | Total number of requests | Counter | Count |
| `tgi_request_duration` | Total time spent processing the request (e2e latency) | Histogram | Seconds |
| `tgi_request_generated_tokens` | Generated tokens per request | Histogram | Count |
| `tgi_request_inference_duration` | Request inference duration | Histogram | Seconds |
| `tgi_request_input_length` | Input token length per request | Histogram | Count |
| `tgi_request_max_new_tokens` | Maximum new tokens per request | Histogram | Count |
| `tgi_request_mean_time_per_token_duration` | Mean time per token per request (inter-token latency) | Histogram | Seconds |
| `tgi_request_queue_duration` | Time spent in the queue per request | Histogram | Seconds |
| `tgi_request_skipped_tokens` | Speculated tokens per request | Histogram | Count |
| `tgi_request_success`                      | Number of successful requests                                                              | Counter   | Count   |
| `tgi_request_validation_duration` | Time spent validating the request | Histogram | Seconds |
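As a quick sanity check before wiring up Prometheus, the endpoint can be scraped directly. A minimal sketch, assuming TGI is exposed on `localhost:8080` (as in the docker examples above):
```python
# Minimal sketch: fetch TGI's Prometheus metrics and print only the tgi_* series.
# Assumes the server is reachable on localhost:8080; adjust for your deployment.
import requests

resp = requests.get("http://localhost:8080/metrics", timeout=10)
resp.raise_for_status()

for line in resp.text.splitlines():
    if line.startswith("tgi_"):  # skip HELP/TYPE comments and non-TGI series
        print(line)
```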

View File

@ -1,22 +1,22 @@
# Supported Models and Hardware # Supported Models and Hardware
Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models are hardware are supported. Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models (VLMs & LLMs) are supported.
## Supported Models ## Supported Models
- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2) - [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal) - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal) - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
- [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) - [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Gemma](https://huggingface.co/google/gemma-7b) - [Gemma](https://huggingface.co/google/gemma-7b)
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224) - [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
- [Gemma2](https://huggingface.co/google/gemma2-9b) - [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus) - [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
- [Dbrx](https://huggingface.co/databricks/dbrx-instruct) - [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
- [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj) - [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj)
- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) - [Mistral](https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407)
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1) - [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1)
- [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder) - [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder)
- [Phi](https://huggingface.co/microsoft/phi-1_5) - [Phi](https://huggingface.co/microsoft/phi-1_5)
@ -32,6 +32,7 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Mpt](https://huggingface.co/mosaicml/mpt-7b-instruct) - [Mpt](https://huggingface.co/mosaicml/mpt-7b-instruct)
- [Gpt2](https://huggingface.co/openai-community/gpt2) - [Gpt2](https://huggingface.co/openai-community/gpt2)
- [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b) - [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
- [Gptj](https://huggingface.co/EleutherAI/gpt-j-6b)
- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) (Multimodal) - [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) (Multimodal)

963
flake.lock Normal file
View File

@ -0,0 +1,963 @@
{
"nodes": {
"cachix": {
"inputs": {
"devenv": [
"crate2nix"
],
"flake-compat": [
"crate2nix"
],
"nixpkgs": "nixpkgs",
"pre-commit-hooks": [
"crate2nix"
]
},
"locked": {
"lastModified": 1709700175,
"narHash": "sha256-A0/6ZjLmT9qdYzKHmevnEIC7G+GiZ4UCr8v0poRPzds=",
"owner": "cachix",
"repo": "cachix",
"rev": "be97b37989f11b724197b5f4c7ffd78f12c8c4bf",
"type": "github"
},
"original": {
"owner": "cachix",
"ref": "latest",
"repo": "cachix",
"type": "github"
}
},
"cachix_2": {
"inputs": {
"devenv": [
"crate2nix",
"crate2nix_stable"
],
"flake-compat": [
"crate2nix",
"crate2nix_stable"
],
"nixpkgs": "nixpkgs_2",
"pre-commit-hooks": [
"crate2nix",
"crate2nix_stable"
]
},
"locked": {
"lastModified": 1716549461,
"narHash": "sha256-lHy5kgx6J8uD+16SO47dPrbob98sh+W1tf4ceSqPVK4=",
"owner": "cachix",
"repo": "cachix",
"rev": "e2bb269fb8c0828d5d4d2d7b8d09ea85abcacbd4",
"type": "github"
},
"original": {
"owner": "cachix",
"ref": "latest",
"repo": "cachix",
"type": "github"
}
},
"cachix_3": {
"inputs": {
"devenv": [
"crate2nix",
"crate2nix_stable",
"crate2nix_stable"
],
"flake-compat": [
"crate2nix",
"crate2nix_stable",
"crate2nix_stable"
],
"nixpkgs": "nixpkgs_3",
"pre-commit-hooks": [
"crate2nix",
"crate2nix_stable",
"crate2nix_stable"
]
},
"locked": {
"lastModified": 1716549461,
"narHash": "sha256-lHy5kgx6J8uD+16SO47dPrbob98sh+W1tf4ceSqPVK4=",
"owner": "cachix",
"repo": "cachix",
"rev": "e2bb269fb8c0828d5d4d2d7b8d09ea85abcacbd4",
"type": "github"
},
"original": {
"owner": "cachix",
"ref": "latest",
"repo": "cachix",
"type": "github"
}
},
"crate2nix": {
"inputs": {
"cachix": "cachix",
"crate2nix_stable": "crate2nix_stable",
"devshell": "devshell_3",
"flake-compat": "flake-compat_3",
"flake-parts": "flake-parts_3",
"nix-test-runner": "nix-test-runner_3",
"nixpkgs": [
"tgi-nix",
"nixpkgs"
],
"pre-commit-hooks": "pre-commit-hooks_3"
},
"locked": {
"lastModified": 1723311214,
"narHash": "sha256-xdGZQBEa1AC2us/sY3igS/CucWY6jErXsAvCFRhB2LI=",
"owner": "nix-community",
"repo": "crate2nix",
"rev": "236f6addfd452a48be805819e3216af79e988fd5",
"type": "github"
},
"original": {
"owner": "nix-community",
"repo": "crate2nix",
"type": "github"
}
},
"crate2nix_stable": {
"inputs": {
"cachix": "cachix_2",
"crate2nix_stable": "crate2nix_stable_2",
"devshell": "devshell_2",
"flake-compat": "flake-compat_2",
"flake-parts": "flake-parts_2",
"nix-test-runner": "nix-test-runner_2",
"nixpkgs": "nixpkgs_5",
"pre-commit-hooks": "pre-commit-hooks_2"
},
"locked": {
"lastModified": 1719760004,
"narHash": "sha256-esWhRnt7FhiYq0CcIxw9pvH+ybOQmWBfHYMtleaMhBE=",
"owner": "nix-community",
"repo": "crate2nix",
"rev": "1dee214bb20855fa3e1e7bb98d28922ddaff8c57",
"type": "github"
},
"original": {
"owner": "nix-community",
"ref": "0.14.1",
"repo": "crate2nix",
"type": "github"
}
},
"crate2nix_stable_2": {
"inputs": {
"cachix": "cachix_3",
"crate2nix_stable": "crate2nix_stable_3",
"devshell": "devshell",
"flake-compat": "flake-compat",
"flake-parts": "flake-parts",
"nix-test-runner": "nix-test-runner",
"nixpkgs": "nixpkgs_4",
"pre-commit-hooks": "pre-commit-hooks"
},
"locked": {
"lastModified": 1712821484,
"narHash": "sha256-rGT3CW64cJS9nlnWPFWSc1iEa3dNZecVVuPVGzcsHe8=",
"owner": "nix-community",
"repo": "crate2nix",
"rev": "42883afcad3823fa5811e967fb7bff54bc3c9d6d",
"type": "github"
},
"original": {
"owner": "nix-community",
"ref": "0.14.0",
"repo": "crate2nix",
"type": "github"
}
},
"crate2nix_stable_3": {
"inputs": {
"flake-utils": "flake-utils"
},
"locked": {
"lastModified": 1702842982,
"narHash": "sha256-A9AowkHIjsy1a4LuiPiVP88FMxyCWK41flZEZOUuwQM=",
"owner": "nix-community",
"repo": "crate2nix",
"rev": "75ac2973affa6b9b4f661a7b592cba6e4f51d426",
"type": "github"
},
"original": {
"owner": "nix-community",
"ref": "0.12.0",
"repo": "crate2nix",
"type": "github"
}
},
"devshell": {
"inputs": {
"flake-utils": "flake-utils_2",
"nixpkgs": [
"crate2nix",
"crate2nix_stable",
"crate2nix_stable",
"nixpkgs"
]
},
"locked": {
"lastModified": 1717408969,
"narHash": "sha256-Q0OEFqe35fZbbRPPRdrjTUUChKVhhWXz3T9ZSKmaoVY=",
"owner": "numtide",
"repo": "devshell",
"rev": "1ebbe68d57457c8cae98145410b164b5477761f4",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "devshell",
"type": "github"
}
},
"devshell_2": {
"inputs": {
"flake-utils": "flake-utils_3",
"nixpkgs": [
"crate2nix",
"crate2nix_stable",
"nixpkgs"
]
},
"locked": {
"lastModified": 1717408969,
"narHash": "sha256-Q0OEFqe35fZbbRPPRdrjTUUChKVhhWXz3T9ZSKmaoVY=",
"owner": "numtide",
"repo": "devshell",
"rev": "1ebbe68d57457c8cae98145410b164b5477761f4",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "devshell",
"type": "github"
}
},
"devshell_3": {
"inputs": {
"flake-utils": "flake-utils_4",
"nixpkgs": [
"crate2nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1711099426,
"narHash": "sha256-HzpgM/wc3aqpnHJJ2oDqPBkNsqWbW0WfWUO8lKu8nGk=",
"owner": "numtide",
"repo": "devshell",
"rev": "2d45b54ca4a183f2fdcf4b19c895b64fbf620ee8",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "devshell",
"type": "github"
}
},
"flake-compat": {
"locked": {
"lastModified": 1696426674,
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
"revCount": 57,
"type": "tarball",
"url": "https://api.flakehub.com/f/pinned/edolstra/flake-compat/1.0.1/018afb31-abd1-7bff-a5e4-cff7e18efb7a/source.tar.gz"
},
"original": {
"type": "tarball",
"url": "https://flakehub.com/f/edolstra/flake-compat/1.tar.gz"
}
},
"flake-compat_2": {
"locked": {
"lastModified": 1696426674,
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
"revCount": 57,
"type": "tarball",
"url": "https://api.flakehub.com/f/pinned/edolstra/flake-compat/1.0.1/018afb31-abd1-7bff-a5e4-cff7e18efb7a/source.tar.gz"
},
"original": {
"type": "tarball",
"url": "https://flakehub.com/f/edolstra/flake-compat/1.tar.gz"
}
},
"flake-compat_3": {
"locked": {
"lastModified": 1696426674,
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
"revCount": 57,
"type": "tarball",
"url": "https://api.flakehub.com/f/pinned/edolstra/flake-compat/1.0.1/018afb31-abd1-7bff-a5e4-cff7e18efb7a/source.tar.gz"
},
"original": {
"type": "tarball",
"url": "https://flakehub.com/f/edolstra/flake-compat/1.tar.gz"
}
},
"flake-compat_4": {
"locked": {
"lastModified": 1696426674,
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
"owner": "edolstra",
"repo": "flake-compat",
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
"type": "github"
},
"original": {
"owner": "edolstra",
"repo": "flake-compat",
"type": "github"
}
},
"flake-parts": {
"inputs": {
"nixpkgs-lib": [
"crate2nix",
"crate2nix_stable",
"crate2nix_stable",
"nixpkgs"
]
},
"locked": {
"lastModified": 1719745305,
"narHash": "sha256-xwgjVUpqSviudEkpQnioeez1Uo2wzrsMaJKJClh+Bls=",
"owner": "hercules-ci",
"repo": "flake-parts",
"rev": "c3c5ecc05edc7dafba779c6c1a61cd08ac6583e9",
"type": "github"
},
"original": {
"owner": "hercules-ci",
"repo": "flake-parts",
"type": "github"
}
},
"flake-parts_2": {
"inputs": {
"nixpkgs-lib": [
"crate2nix",
"crate2nix_stable",
"nixpkgs"
]
},
"locked": {
"lastModified": 1719745305,
"narHash": "sha256-xwgjVUpqSviudEkpQnioeez1Uo2wzrsMaJKJClh+Bls=",
"owner": "hercules-ci",
"repo": "flake-parts",
"rev": "c3c5ecc05edc7dafba779c6c1a61cd08ac6583e9",
"type": "github"
},
"original": {
"owner": "hercules-ci",
"repo": "flake-parts",
"type": "github"
}
},
"flake-parts_3": {
"inputs": {
"nixpkgs-lib": [
"crate2nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1712014858,
"narHash": "sha256-sB4SWl2lX95bExY2gMFG5HIzvva5AVMJd4Igm+GpZNw=",
"owner": "hercules-ci",
"repo": "flake-parts",
"rev": "9126214d0a59633752a136528f5f3b9aa8565b7d",
"type": "github"
},
"original": {
"owner": "hercules-ci",
"repo": "flake-parts",
"type": "github"
}
},
"flake-utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1694529238,
"narHash": "sha256-zsNZZGTGnMOf9YpHKJqMSsa0dXbfmxeoJ7xHlrt+xmY=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "ff7b65b44d01cf9ba6a71320833626af21126384",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"flake-utils_2": {
"inputs": {
"systems": "systems_2"
},
"locked": {
"lastModified": 1701680307,
"narHash": "sha256-kAuep2h5ajznlPMD9rnQyffWG8EM/C73lejGofXvdM8=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "4022d587cbbfd70fe950c1e2083a02621806a725",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"flake-utils_3": {
"inputs": {
"systems": "systems_3"
},
"locked": {
"lastModified": 1701680307,
"narHash": "sha256-kAuep2h5ajznlPMD9rnQyffWG8EM/C73lejGofXvdM8=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "4022d587cbbfd70fe950c1e2083a02621806a725",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"flake-utils_4": {
"inputs": {
"systems": "systems_4"
},
"locked": {
"lastModified": 1701680307,
"narHash": "sha256-kAuep2h5ajznlPMD9rnQyffWG8EM/C73lejGofXvdM8=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "4022d587cbbfd70fe950c1e2083a02621806a725",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"flake-utils_5": {
"inputs": {
"systems": "systems_5"
},
"locked": {
"lastModified": 1710146030,
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"flake-utils_6": {
"inputs": {
"systems": "systems_6"
},
"locked": {
"lastModified": 1710146030,
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"gitignore": {
"inputs": {
"nixpkgs": [
"crate2nix",
"crate2nix_stable",
"crate2nix_stable",
"pre-commit-hooks",
"nixpkgs"
]
},
"locked": {
"lastModified": 1709087332,
"narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=",
"owner": "hercules-ci",
"repo": "gitignore.nix",
"rev": "637db329424fd7e46cf4185293b9cc8c88c95394",
"type": "github"
},
"original": {
"owner": "hercules-ci",
"repo": "gitignore.nix",
"type": "github"
}
},
"gitignore_2": {
"inputs": {
"nixpkgs": [
"crate2nix",
"crate2nix_stable",
"pre-commit-hooks",
"nixpkgs"
]
},
"locked": {
"lastModified": 1709087332,
"narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=",
"owner": "hercules-ci",
"repo": "gitignore.nix",
"rev": "637db329424fd7e46cf4185293b9cc8c88c95394",
"type": "github"
},
"original": {
"owner": "hercules-ci",
"repo": "gitignore.nix",
"type": "github"
}
},
"gitignore_3": {
"inputs": {
"nixpkgs": [
"crate2nix",
"pre-commit-hooks",
"nixpkgs"
]
},
"locked": {
"lastModified": 1709087332,
"narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=",
"owner": "hercules-ci",
"repo": "gitignore.nix",
"rev": "637db329424fd7e46cf4185293b9cc8c88c95394",
"type": "github"
},
"original": {
"owner": "hercules-ci",
"repo": "gitignore.nix",
"type": "github"
}
},
"nix-filter": {
"locked": {
"lastModified": 1710156097,
"narHash": "sha256-1Wvk8UP7PXdf8bCCaEoMnOT1qe5/Duqgj+rL8sRQsSM=",
"owner": "numtide",
"repo": "nix-filter",
"rev": "3342559a24e85fc164b295c3444e8a139924675b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "nix-filter",
"type": "github"
}
},
"nix-test-runner": {
"flake": false,
"locked": {
"lastModified": 1588761593,
"narHash": "sha256-FKJykltAN/g3eIceJl4SfDnnyuH2jHImhMrXS2KvGIs=",
"owner": "stoeffel",
"repo": "nix-test-runner",
"rev": "c45d45b11ecef3eb9d834c3b6304c05c49b06ca2",
"type": "github"
},
"original": {
"owner": "stoeffel",
"repo": "nix-test-runner",
"type": "github"
}
},
"nix-test-runner_2": {
"flake": false,
"locked": {
"lastModified": 1588761593,
"narHash": "sha256-FKJykltAN/g3eIceJl4SfDnnyuH2jHImhMrXS2KvGIs=",
"owner": "stoeffel",
"repo": "nix-test-runner",
"rev": "c45d45b11ecef3eb9d834c3b6304c05c49b06ca2",
"type": "github"
},
"original": {
"owner": "stoeffel",
"repo": "nix-test-runner",
"type": "github"
}
},
"nix-test-runner_3": {
"flake": false,
"locked": {
"lastModified": 1588761593,
"narHash": "sha256-FKJykltAN/g3eIceJl4SfDnnyuH2jHImhMrXS2KvGIs=",
"owner": "stoeffel",
"repo": "nix-test-runner",
"rev": "c45d45b11ecef3eb9d834c3b6304c05c49b06ca2",
"type": "github"
},
"original": {
"owner": "stoeffel",
"repo": "nix-test-runner",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1700612854,
"narHash": "sha256-yrQ8osMD+vDLGFX7pcwsY/Qr5PUd6OmDMYJZzZi0+zc=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "19cbff58383a4ae384dea4d1d0c823d72b49d614",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"nixpkgs_2": {
"locked": {
"lastModified": 1715534503,
"narHash": "sha256-5ZSVkFadZbFP1THataCaSf0JH2cAH3S29hU9rrxTEqk=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "2057814051972fa1453ddfb0d98badbea9b83c06",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"nixpkgs_3": {
"locked": {
"lastModified": 1715534503,
"narHash": "sha256-5ZSVkFadZbFP1THataCaSf0JH2cAH3S29hU9rrxTEqk=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "2057814051972fa1453ddfb0d98badbea9b83c06",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"nixpkgs_4": {
"locked": {
"lastModified": 1719506693,
"narHash": "sha256-C8e9S7RzshSdHB7L+v9I51af1gDM5unhJ2xO1ywxNH8=",
"path": "/nix/store/4p0avw1s3vf27hspgqsrqs37gxk4i83i-source",
"rev": "b2852eb9365c6de48ffb0dc2c9562591f652242a",
"type": "path"
},
"original": {
"id": "nixpkgs",
"type": "indirect"
}
},
"nixpkgs_5": {
"locked": {
"lastModified": 1719506693,
"narHash": "sha256-C8e9S7RzshSdHB7L+v9I51af1gDM5unhJ2xO1ywxNH8=",
"path": "/nix/store/4p0avw1s3vf27hspgqsrqs37gxk4i83i-source",
"rev": "b2852eb9365c6de48ffb0dc2c9562591f652242a",
"type": "path"
},
"original": {
"id": "nixpkgs",
"type": "indirect"
}
},
"nixpkgs_6": {
"locked": {
"lastModified": 1723912943,
"narHash": "sha256-39F9GzyhxYcY3wTeKuEFWRJWcrGBosO4nf4xzMTWZX8=",
"owner": "danieldk",
"repo": "nixpkgs",
"rev": "b82cdca86dbb30013b76c4b55d48806476820a5c",
"type": "github"
},
"original": {
"owner": "danieldk",
"ref": "cuda-12.4",
"repo": "nixpkgs",
"type": "github"
}
},
"pre-commit-hooks": {
"inputs": {
"flake-compat": [
"crate2nix",
"crate2nix_stable",
"crate2nix_stable",
"flake-compat"
],
"gitignore": "gitignore",
"nixpkgs": [
"crate2nix",
"crate2nix_stable",
"crate2nix_stable",
"nixpkgs"
],
"nixpkgs-stable": [
"crate2nix",
"crate2nix_stable",
"crate2nix_stable",
"nixpkgs"
]
},
"locked": {
"lastModified": 1719259945,
"narHash": "sha256-F1h+XIsGKT9TkGO3omxDLEb/9jOOsI6NnzsXFsZhry4=",
"owner": "cachix",
"repo": "pre-commit-hooks.nix",
"rev": "0ff4381bbb8f7a52ca4a851660fc7a437a4c6e07",
"type": "github"
},
"original": {
"owner": "cachix",
"repo": "pre-commit-hooks.nix",
"type": "github"
}
},
"pre-commit-hooks_2": {
"inputs": {
"flake-compat": [
"crate2nix",
"crate2nix_stable",
"flake-compat"
],
"gitignore": "gitignore_2",
"nixpkgs": [
"crate2nix",
"crate2nix_stable",
"nixpkgs"
],
"nixpkgs-stable": [
"crate2nix",
"crate2nix_stable",
"nixpkgs"
]
},
"locked": {
"lastModified": 1719259945,
"narHash": "sha256-F1h+XIsGKT9TkGO3omxDLEb/9jOOsI6NnzsXFsZhry4=",
"owner": "cachix",
"repo": "pre-commit-hooks.nix",
"rev": "0ff4381bbb8f7a52ca4a851660fc7a437a4c6e07",
"type": "github"
},
"original": {
"owner": "cachix",
"repo": "pre-commit-hooks.nix",
"type": "github"
}
},
"pre-commit-hooks_3": {
"inputs": {
"flake-compat": [
"crate2nix",
"flake-compat"
],
"flake-utils": "flake-utils_5",
"gitignore": "gitignore_3",
"nixpkgs": [
"crate2nix",
"nixpkgs"
],
"nixpkgs-stable": [
"crate2nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1712055707,
"narHash": "sha256-4XLvuSIDZJGS17xEwSrNuJLL7UjDYKGJSbK1WWX2AK8=",
"owner": "cachix",
"repo": "pre-commit-hooks.nix",
"rev": "e35aed5fda3cc79f88ed7f1795021e559582093a",
"type": "github"
},
"original": {
"owner": "cachix",
"repo": "pre-commit-hooks.nix",
"type": "github"
}
},
"root": {
"inputs": {
"crate2nix": "crate2nix",
"flake-utils": "flake-utils_6",
"nix-filter": "nix-filter",
"nixpkgs": [
"tgi-nix",
"nixpkgs"
],
"rust-overlay": "rust-overlay",
"tgi-nix": "tgi-nix"
}
},
"rust-overlay": {
"inputs": {
"nixpkgs": [
"tgi-nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1724638882,
"narHash": "sha256-ap2jIQi/FuUHR6HCht6ASWhoz8EiB99XmI8Esot38VE=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "19b70f147b9c67a759e35824b241f1ed92e46694",
"type": "github"
},
"original": {
"owner": "oxalica",
"repo": "rust-overlay",
"type": "github"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"systems_2": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"systems_3": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"systems_4": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"systems_5": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"systems_6": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"tgi-nix": {
"inputs": {
"flake-compat": "flake-compat_4",
"nixpkgs": "nixpkgs_6"
},
"locked": {
"lastModified": 1724784743,
"narHash": "sha256-NdEoWeNwR/ZstYnHaiQWIYZvr7VsrAh7g3+ZHUPrxuI=",
"owner": "danieldk",
"repo": "tgi-nix",
"rev": "c9580c3e39a855246bb87b584bbea1885b44f524",
"type": "github"
},
"original": {
"owner": "danieldk",
"repo": "tgi-nix",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

112
flake.nix Normal file
View File

@ -0,0 +1,112 @@
{
inputs = {
crate2nix = {
url = "github:nix-community/crate2nix";
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
};
nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:danieldk/tgi-nix";
nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
rust-overlay = {
url = "github:oxalica/rust-overlay";
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
};
};
outputs =
{
self,
crate2nix,
nix-filter,
nixpkgs,
flake-utils,
rust-overlay,
tgi-nix,
}:
flake-utils.lib.eachDefaultSystem (
system:
let
cargoNix = crate2nix.tools.${system}.appliedCargoNix {
name = "tgi";
src = ./.;
additionalCargoNixArgs = [ "--all-features" ];
};
pkgs = import nixpkgs {
inherit system;
inherit (tgi-nix.lib) config;
overlays = [
rust-overlay.overlays.default
tgi-nix.overlays.default
];
};
crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; };
benchmark = cargoNix.workspaceMembers.text-generation-benchmark.build.override {
inherit crateOverrides;
};
launcher = cargoNix.workspaceMembers.text-generation-launcher.build.override {
inherit crateOverrides;
};
router = cargoNix.workspaceMembers.text-generation-router-v3.build.override {
inherit crateOverrides;
};
server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
in
{
devShells = with pkgs; rec {
default = pure;
pure = mkShell {
buildInputs = [
benchmark
launcher
router
server
];
};
impure = mkShell {
buildInputs =
[
openssl.dev
pkg-config
(rust-bin.stable.latest.default.override {
extensions = [
"rust-analyzer"
"rust-src"
];
})
protobuf
]
++ (with python3.pkgs; [
venvShellHook
pip
ipdb
]);
inputsFrom = [ server ];
venvDir = "./.venv";
postVenv = ''
unset SOURCE_DATE_EPOCH
'';
postShellHook = ''
unset SOURCE_DATE_EPOCH
export PATH=$PATH:~/.cargo/bin
'';
};
};
packages.default = pkgs.writeShellApplication {
name = "text-generation-inference";
runtimeInputs = [
server
router
];
text = ''
${launcher}/bin/text-generation-launcher "$@"
'';
};
}
);
}

View File

@ -118,6 +118,7 @@ class ResponseComparator(JSONSnapshotExtension):
and token.text == other.text and token.text == other.text
and ( and (
self.ignore_logprob self.ignore_logprob
or (token.logprob == other.logprob and token.logprob is None)
or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol) or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
) )
and token.special == other.special and token.special == other.special
@ -256,7 +257,7 @@ class IgnoreLogProbResponseComparator(ResponseComparator):
class LauncherHandle: class LauncherHandle:
def __init__(self, port: int): def __init__(self, port: int):
self.client = AsyncClient(f"http://localhost:{port}") self.client = AsyncClient(f"http://localhost:{port}", timeout=30)
def _inner_health(self): def _inner_health(self):
raise NotImplementedError raise NotImplementedError

View File

@ -41,22 +41,22 @@
}, },
{ {
"id": 1669, "id": 1669,
"logprob": -1.5664062, "logprob": -1.5595703,
"text": " il" "text": " il"
}, },
{ {
"id": 11580, "id": 11580,
"logprob": -0.94189453, "logprob": -0.9428711,
"text": " faut" "text": " faut"
}, },
{ {
"id": 3913, "id": 3913,
"logprob": -3.6816406, "logprob": -3.703125,
"text": " tout" "text": " tout"
}, },
{ {
"id": 39261, "id": 39261,
"logprob": -1.7753906, "logprob": -1.7763672,
"text": " d'abord" "text": " d'abord"
} }
], ],
@ -64,7 +64,7 @@
"tokens": [ "tokens": [
{ {
"id": 578, "id": 578,
"logprob": -1.6318359, "logprob": -1.7822266,
"special": false, "special": false,
"text": " le" "text": " le"
}, },
@ -76,7 +76,7 @@
}, },
{ {
"id": 7735, "id": 7735,
"logprob": -2.4355469, "logprob": -2.4199219,
"special": false, "special": false,
"text": " fond" "text": " fond"
}, },
@ -88,19 +88,19 @@
}, },
{ {
"id": 693, "id": 693,
"logprob": -2.4472656, "logprob": -2.4628906,
"special": false, "special": false,
"text": " à" "text": " à"
}, },
{ {
"id": 366, "id": 366,
"logprob": -1.1972656, "logprob": -1.1308594,
"special": false, "special": false,
"text": " la" "text": " la"
}, },
{ {
"id": 48844, "id": 48844,
"logprob": -1.7890625, "logprob": -1.7900391,
"special": false, "special": false,
"text": " cass" "text": " cass"
}, },
@ -118,7 +118,7 @@
}, },
{ {
"id": 2940, "id": 2940,
"logprob": -1.9335938, "logprob": -1.9306641,
"special": false, "special": false,
"text": " avec" "text": " avec"
} }

View File

@ -5,7 +5,7 @@
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"message": { "message": {
"content": "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas", "content": "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to appreciate nature.\n\nIn terms of temperature, the warmest times of the year are from June to August, when average high temperatures typically range from around 73°F or 23°C",
"name": null, "name": null,
"role": "assistant", "role": "assistant",
"tool_calls": null "tool_calls": null
@ -13,14 +13,14 @@
"usage": null "usage": null
} }
], ],
"created": 1716553098, "created": 1724792495,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "chat.completion",
"system_fingerprint": "2.0.5-dev0-native", "system_fingerprint": "2.2.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 100, "completion_tokens": 100,
"prompt_tokens": 62, "prompt_tokens": 61,
"total_tokens": 162 "total_tokens": 161
} }
} }

View File

@ -8,11 +8,11 @@
"text": "\n" "text": "\n"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -23,11 +23,11 @@
"text": "\n" "text": "\n"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -38,11 +38,11 @@
"text": "\n" "text": "\n"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -53,11 +53,11 @@
"text": "hd" "text": "hd"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -68,11 +68,11 @@
"text": "\n" "text": "\n"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -83,11 +83,11 @@
"text": "\n" "text": "\n"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -98,11 +98,11 @@
"text": "\n" "text": "\n"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -113,11 +113,11 @@
"text": "aho" "text": "aho"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -128,11 +128,11 @@
"text": "2" "text": "2"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -143,11 +143,11 @@
"text": "2" "text": "2"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -158,11 +158,11 @@
"text": "2" "text": "2"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -173,11 +173,11 @@
"text": "ima" "text": "ima"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -188,11 +188,11 @@
"text": "." "text": "."
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -203,11 +203,11 @@
"text": "." "text": "."
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -218,11 +218,11 @@
"text": "." "text": "."
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -233,11 +233,11 @@
"text": "\n" "text": "\n"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -248,11 +248,11 @@
"text": " Sarah" "text": " Sarah"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -263,11 +263,11 @@
"text": " Yes" "text": " Yes"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -278,11 +278,11 @@
"text": " And" "text": " And"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -293,11 +293,11 @@
"text": "i" "text": "i"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -308,11 +308,11 @@
"text": "'" "text": "'"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -323,11 +323,11 @@
"text": "," "text": ","
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -338,11 +338,11 @@
"text": " what" "text": " what"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -353,11 +353,11 @@
"text": "'" "text": "'"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -368,11 +368,11 @@
"text": "s" "text": "s"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -383,11 +383,11 @@
"text": " Moh" "text": " Moh"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -398,11 +398,11 @@
"text": " is" "text": " is"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -413,11 +413,11 @@
"text": "m" "text": "m"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -428,11 +428,11 @@
"text": " Room" "text": " Room"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -443,11 +443,11 @@
"text": "s" "text": "s"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -458,11 +458,11 @@
"text": " the" "text": " the"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -473,11 +473,11 @@
"text": " tired" "text": " tired"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -488,11 +488,11 @@
"text": ":" "text": ":"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -503,11 +503,11 @@
"text": "'" "text": "'"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -518,11 +518,11 @@
"text": " capital" "text": " capital"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
@ -530,73 +530,73 @@
"finish_reason": "", "finish_reason": "",
"index": 3, "index": 3,
"logprobs": null, "logprobs": null,
"text": " of" "text": ","
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
{ {
"finish_reason": "", "finish_reason": "length",
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"text": " She" "text": " She"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
{ {
"finish_reason": "", "finish_reason": "length",
"index": 1, "index": 1,
"logprobs": null, "logprobs": null,
"text": " scale" "text": " scale"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
{ {
"finish_reason": "", "finish_reason": "length",
"index": 2, "index": 2,
"logprobs": null, "logprobs": null,
"text": " of" "text": " of"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
}, },
{ {
"choices": [ "choices": [
{ {
"finish_reason": "", "finish_reason": "length",
"index": 3, "index": 3,
"logprobs": null, "logprobs": null,
"text": " being" "text": " its"
} }
], ],
"created": 1713284431, "created": 1724833943,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion", "object": "text_completion",
"system_fingerprint": "2.0.1-native" "system_fingerprint": "2.2.1-dev0-native"
} }
] ]

View File

@ -16,7 +16,7 @@
}, },
{ {
"id": 3102, "id": 3102,
"logprob": -11.1875, "logprob": -11.25,
"text": " request" "text": " request"
} }
], ],
@ -24,66 +24,66 @@
"tokens": [ "tokens": [
{ {
"id": 185, "id": 185,
"logprob": -1.5546875, "logprob": -1.546875,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 549, "id": 549,
"logprob": -2.84375, "logprob": -2.859375,
"special": false, "special": false,
"text": "The" "text": "The"
}, },
{ {
"id": 1727, "id": 1727,
"logprob": -2.34375, "logprob": -2.484375,
"special": false, "special": false,
"text": " test" "text": " test"
}, },
{ {
"id": 3102, "id": 3102,
"logprob": -0.8359375, "logprob": -0.83203125,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 317, "id": 317,
"logprob": -1.0859375, "logprob": -1.1484375,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 254, "id": 245,
"logprob": -1.5390625, "logprob": -1.578125,
"special": false, "special": false,
"text": " the" "text": " a"
}, },
{ {
"id": 1022, "id": 3412,
"logprob": -1.1875, "logprob": -2.578125,
"special": false, "special": false,
"text": " first" "text": " document"
}, },
{ {
"id": 3458, "id": 344,
"logprob": -0.35546875, "logprob": -1.125,
"special": false, "special": false,
"text": " step" "text": " that"
}, },
{ {
"id": 279, "id": 317,
"logprob": -0.8828125, "logprob": -1.6953125,
"special": false, "special": false,
"text": " in" "text": " is"
}, },
{ {
"id": 254, "id": 1222,
"logprob": -0.71484375, "logprob": -1.71875,
"special": false, "special": false,
"text": " the" "text": " used"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nThe test request is the first step in the" "generated_text": "\nThe test request is a document that is used"
} }

View File

@ -37,56 +37,56 @@
}, },
{ {
"id": 1727, "id": 1727,
"logprob": -2.359375, "logprob": -2.4375,
"special": false, "special": false,
"text": " test" "text": " test"
}, },
{ {
"id": 3102, "id": 3102,
"logprob": -0.83203125, "logprob": -0.83984375,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 317, "id": 317,
"logprob": -1.125, "logprob": -1.1328125,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 245, "id": 254,
"logprob": -1.5703125, "logprob": -1.515625,
"special": false, "special": false,
"text": " a" "text": " the"
}, },
{ {
"id": 3412, "id": 1022,
"logprob": -2.578125, "logprob": -1.15625,
"special": false, "special": false,
"text": " document" "text": " first"
}, },
{ {
"id": 344, "id": 3458,
"logprob": -1.125, "logprob": -0.3671875,
"special": false, "special": false,
"text": " that" "text": " step"
}, },
{ {
"id": 317, "id": 279,
"logprob": -1.6953125, "logprob": -0.88671875,
"special": false, "special": false,
"text": " is" "text": " in"
}, },
{ {
"id": 1222, "id": 254,
"logprob": -1.75, "logprob": -0.69140625,
"special": false, "special": false,
"text": " used" "text": " the"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nThe test request is a document that is used" "generated_text": "\nThe test request is the first step in the"
}, },
{ {
"details": { "details": {
@ -126,56 +126,56 @@
}, },
{ {
"id": 1727, "id": 1727,
"logprob": -2.359375, "logprob": -2.4375,
"special": false, "special": false,
"text": " test" "text": " test"
}, },
{ {
"id": 3102, "id": 3102,
"logprob": -0.83203125, "logprob": -0.83984375,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 317, "id": 317,
"logprob": -1.125, "logprob": -1.1328125,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 245, "id": 254,
"logprob": -1.5703125, "logprob": -1.515625,
"special": false, "special": false,
"text": " a" "text": " the"
}, },
{ {
"id": 3412, "id": 1022,
"logprob": -2.578125, "logprob": -1.15625,
"special": false, "special": false,
"text": " document" "text": " first"
}, },
{ {
"id": 344, "id": 3458,
"logprob": -1.125, "logprob": -0.3671875,
"special": false, "special": false,
"text": " that" "text": " step"
}, },
{ {
"id": 317, "id": 279,
"logprob": -1.6953125, "logprob": -0.88671875,
"special": false, "special": false,
"text": " is" "text": " in"
}, },
{ {
"id": 1222, "id": 254,
"logprob": -1.75, "logprob": -0.69140625,
"special": false, "special": false,
"text": " used" "text": " the"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nThe test request is a document that is used" "generated_text": "\nThe test request is the first step in the"
}, },
{ {
"details": { "details": {
@ -215,56 +215,56 @@
}, },
{ {
"id": 1727, "id": 1727,
"logprob": -2.359375, "logprob": -2.4375,
"special": false, "special": false,
"text": " test" "text": " test"
}, },
{ {
"id": 3102, "id": 3102,
"logprob": -0.83203125, "logprob": -0.83984375,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 317, "id": 317,
"logprob": -1.125, "logprob": -1.1328125,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 245, "id": 254,
"logprob": -1.5703125, "logprob": -1.515625,
"special": false, "special": false,
"text": " a" "text": " the"
}, },
{ {
"id": 3412, "id": 1022,
"logprob": -2.578125, "logprob": -1.15625,
"special": false, "special": false,
"text": " document" "text": " first"
}, },
{ {
"id": 344, "id": 3458,
"logprob": -1.125, "logprob": -0.3671875,
"special": false, "special": false,
"text": " that" "text": " step"
}, },
{ {
"id": 317, "id": 279,
"logprob": -1.6953125, "logprob": -0.88671875,
"special": false, "special": false,
"text": " is" "text": " in"
}, },
{ {
"id": 1222, "id": 254,
"logprob": -1.75, "logprob": -0.69140625,
"special": false, "special": false,
"text": " used" "text": " the"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nThe test request is a document that is used" "generated_text": "\nThe test request is the first step in the"
}, },
{ {
"details": { "details": {
@ -304,55 +304,55 @@
}, },
{ {
"id": 1727, "id": 1727,
"logprob": -2.359375, "logprob": -2.4375,
"special": false, "special": false,
"text": " test" "text": " test"
}, },
{ {
"id": 3102, "id": 3102,
"logprob": -0.83203125, "logprob": -0.83984375,
"special": false, "special": false,
"text": " request" "text": " request"
}, },
{ {
"id": 317, "id": 317,
"logprob": -1.125, "logprob": -1.1328125,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 245, "id": 254,
"logprob": -1.5703125, "logprob": -1.515625,
"special": false, "special": false,
"text": " a" "text": " the"
}, },
{ {
"id": 3412, "id": 1022,
"logprob": -2.578125, "logprob": -1.15625,
"special": false, "special": false,
"text": " document" "text": " first"
}, },
{ {
"id": 344, "id": 3458,
"logprob": -1.125, "logprob": -0.3671875,
"special": false, "special": false,
"text": " that" "text": " step"
}, },
{ {
"id": 317, "id": 279,
"logprob": -1.6953125, "logprob": -0.88671875,
"special": false, "special": false,
"text": " is" "text": " in"
}, },
{ {
"id": 1222, "id": 254,
"logprob": -1.75, "logprob": -0.69140625,
"special": false, "special": false,
"text": " used" "text": " the"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\nThe test request is a document that is used" "generated_text": "\nThe test request is the first step in the"
} }
] ]


@ -1,8 +1,8 @@
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,
"finish_reason": "length", "finish_reason": "stop_sequence",
"generated_tokens": 10, "generated_tokens": 5,
"prefill": [ "prefill": [
{ {
"id": 128000, "id": 128000,
@ -16,7 +16,7 @@
}, },
{ {
"id": 1715, "id": 1715,
"logprob": -10.375, "logprob": -10.4375,
"text": " request" "text": " request"
} }
], ],
@ -29,61 +29,31 @@
"text": ":" "text": ":"
}, },
{ {
"id": 2209, "id": 923,
"logprob": -2.78125, "logprob": -2.84375,
"special": false, "special": false,
"text": " Is" "text": " add"
}, },
{ {
"id": 279, "id": 264,
"logprob": -0.6328125, "logprob": 0.0,
"special": false, "special": false,
"text": " the" "text": " a"
},
{
"id": 734,
"logprob": -2.703125,
"special": false,
"text": " function"
}, },
{ {
"id": 330, "id": 330,
"logprob": -0.34179688, "logprob": -0.31640625,
"special": false, "special": false,
"text": " \"" "text": " \""
}, },
{ {
"id": 4110, "id": 1985,
"logprob": -2.359375, "logprob": 0.0,
"special": false, "special": false,
"text": "Create" "text": "test"
},
{
"id": 7575,
"logprob": -2.1875,
"special": false,
"text": "Process"
},
{
"id": 1,
"logprob": -0.07910156,
"special": false,
"text": "\""
},
{
"id": 304,
"logprob": -0.83203125,
"special": false,
"text": " in"
},
{
"id": 12468,
"logprob": -1.8203125,
"special": false,
"text": " Win"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "Test request: Is the function \"CreateProcess\" in Win" "generated_text": "Test request: add a \"test"
} }


@ -12,12 +12,12 @@
}, },
{ {
"id": 2323, "id": 2323,
"logprob": -9.421875, "logprob": -9.5625,
"text": "Test" "text": "Test"
}, },
{ {
"id": 1715, "id": 1715,
"logprob": -10.546875, "logprob": -10.375,
"text": " request" "text": " request"
} }
], ],
@ -25,61 +25,61 @@
"tokens": [ "tokens": [
{ {
"id": 369, "id": 369,
"logprob": -2.1816406, "logprob": -2.15625,
"special": false, "special": false,
"text": " for" "text": " for"
}, },
{ {
"id": 279, "id": 279,
"logprob": -2.6992188, "logprob": -2.703125,
"special": false, "special": false,
"text": " the" "text": " the"
}, },
{ {
"id": 220, "id": 220,
"logprob": -3.6308594, "logprob": -3.640625,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 679, "id": 679,
"logprob": -1.7988281, "logprob": -1.703125,
"special": false, "special": false,
"text": "201" "text": "201"
}, },
{ {
"id": 24, "id": 24,
"logprob": -1.3535156, "logprob": -1.421875,
"special": false, "special": false,
"text": "9" "text": "9"
}, },
{ {
"id": 12, "id": 12,
"logprob": -2.0058594, "logprob": -2.03125,
"special": false, "special": false,
"text": "-" "text": "-"
}, },
{ {
"id": 2366, "id": 2366,
"logprob": -0.45410156, "logprob": -0.49023438,
"special": false, "special": false,
"text": "202" "text": "202"
}, },
{ {
"id": 15, "id": 15,
"logprob": -0.037109375, "logprob": -0.041503906,
"special": false, "special": false,
"text": "0" "text": "0"
}, },
{ {
"id": 2978, "id": 2978,
"logprob": -0.8095703, "logprob": -0.87109375,
"special": false, "special": false,
"text": " school" "text": " school"
}, },
{ {
"id": 1060, "id": 1060,
"logprob": -0.013053894, "logprob": -0.012939453,
"special": false, "special": false,
"text": " year" "text": " year"
} }
@ -101,12 +101,12 @@
}, },
{ {
"id": 2323, "id": 2323,
"logprob": -9.421875, "logprob": -9.5625,
"text": "Test" "text": "Test"
}, },
{ {
"id": 1715, "id": 1715,
"logprob": -10.546875, "logprob": -10.375,
"text": " request" "text": " request"
} }
], ],
@ -114,61 +114,61 @@
"tokens": [ "tokens": [
{ {
"id": 369, "id": 369,
"logprob": -2.1816406, "logprob": -2.15625,
"special": false, "special": false,
"text": " for" "text": " for"
}, },
{ {
"id": 279, "id": 279,
"logprob": -2.6992188, "logprob": -2.703125,
"special": false, "special": false,
"text": " the" "text": " the"
}, },
{ {
"id": 220, "id": 220,
"logprob": -3.6308594, "logprob": -3.640625,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 679, "id": 679,
"logprob": -1.7988281, "logprob": -1.703125,
"special": false, "special": false,
"text": "201" "text": "201"
}, },
{ {
"id": 24, "id": 24,
"logprob": -1.3535156, "logprob": -1.421875,
"special": false, "special": false,
"text": "9" "text": "9"
}, },
{ {
"id": 12, "id": 12,
"logprob": -2.0058594, "logprob": -2.03125,
"special": false, "special": false,
"text": "-" "text": "-"
}, },
{ {
"id": 2366, "id": 2366,
"logprob": -0.45410156, "logprob": -0.49023438,
"special": false, "special": false,
"text": "202" "text": "202"
}, },
{ {
"id": 15, "id": 15,
"logprob": -0.037109375, "logprob": -0.041503906,
"special": false, "special": false,
"text": "0" "text": "0"
}, },
{ {
"id": 2978, "id": 2978,
"logprob": -0.8095703, "logprob": -0.87109375,
"special": false, "special": false,
"text": " school" "text": " school"
}, },
{ {
"id": 1060, "id": 1060,
"logprob": -0.013053894, "logprob": -0.012939453,
"special": false, "special": false,
"text": " year" "text": " year"
} }
@ -190,12 +190,12 @@
}, },
{ {
"id": 2323, "id": 2323,
"logprob": -9.421875, "logprob": -9.5625,
"text": "Test" "text": "Test"
}, },
{ {
"id": 1715, "id": 1715,
"logprob": -10.546875, "logprob": -10.375,
"text": " request" "text": " request"
} }
], ],
@ -203,61 +203,61 @@
"tokens": [ "tokens": [
{ {
"id": 369, "id": 369,
"logprob": -2.1816406, "logprob": -2.15625,
"special": false, "special": false,
"text": " for" "text": " for"
}, },
{ {
"id": 279, "id": 279,
"logprob": -2.6992188, "logprob": -2.703125,
"special": false, "special": false,
"text": " the" "text": " the"
}, },
{ {
"id": 220, "id": 220,
"logprob": -3.6308594, "logprob": -3.640625,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 679, "id": 679,
"logprob": -1.7988281, "logprob": -1.703125,
"special": false, "special": false,
"text": "201" "text": "201"
}, },
{ {
"id": 24, "id": 24,
"logprob": -1.3535156, "logprob": -1.421875,
"special": false, "special": false,
"text": "9" "text": "9"
}, },
{ {
"id": 12, "id": 12,
"logprob": -2.0058594, "logprob": -2.03125,
"special": false, "special": false,
"text": "-" "text": "-"
}, },
{ {
"id": 2366, "id": 2366,
"logprob": -0.45410156, "logprob": -0.49023438,
"special": false, "special": false,
"text": "202" "text": "202"
}, },
{ {
"id": 15, "id": 15,
"logprob": -0.037109375, "logprob": -0.041503906,
"special": false, "special": false,
"text": "0" "text": "0"
}, },
{ {
"id": 2978, "id": 2978,
"logprob": -0.8095703, "logprob": -0.87109375,
"special": false, "special": false,
"text": " school" "text": " school"
}, },
{ {
"id": 1060, "id": 1060,
"logprob": -0.013053894, "logprob": -0.012939453,
"special": false, "special": false,
"text": " year" "text": " year"
} }
@ -279,12 +279,12 @@
}, },
{ {
"id": 2323, "id": 2323,
"logprob": -9.421875, "logprob": -9.5625,
"text": "Test" "text": "Test"
}, },
{ {
"id": 1715, "id": 1715,
"logprob": -10.546875, "logprob": -10.375,
"text": " request" "text": " request"
} }
], ],
@ -292,61 +292,61 @@
"tokens": [ "tokens": [
{ {
"id": 369, "id": 369,
"logprob": -2.1816406, "logprob": -2.15625,
"special": false, "special": false,
"text": " for" "text": " for"
}, },
{ {
"id": 279, "id": 279,
"logprob": -2.6992188, "logprob": -2.703125,
"special": false, "special": false,
"text": " the" "text": " the"
}, },
{ {
"id": 220, "id": 220,
"logprob": -3.6308594, "logprob": -3.640625,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 679, "id": 679,
"logprob": -1.7988281, "logprob": -1.703125,
"special": false, "special": false,
"text": "201" "text": "201"
}, },
{ {
"id": 24, "id": 24,
"logprob": -1.3535156, "logprob": -1.421875,
"special": false, "special": false,
"text": "9" "text": "9"
}, },
{ {
"id": 12, "id": 12,
"logprob": -2.0058594, "logprob": -2.03125,
"special": false, "special": false,
"text": "-" "text": "-"
}, },
{ {
"id": 2366, "id": 2366,
"logprob": -0.45410156, "logprob": -0.49023438,
"special": false, "special": false,
"text": "202" "text": "202"
}, },
{ {
"id": 15, "id": 15,
"logprob": -0.037109375, "logprob": -0.041503906,
"special": false, "special": false,
"text": "0" "text": "0"
}, },
{ {
"id": 2978, "id": 2978,
"logprob": -0.8095703, "logprob": -0.87109375,
"special": false, "special": false,
"text": " school" "text": " school"
}, },
{ {
"id": 1060, "id": 1060,
"logprob": -0.013053894, "logprob": -0.012939453,
"special": false, "special": false,
"text": " year" "text": " year"
} }


@ -8,13 +8,13 @@
"tokens": [ "tokens": [
{ {
"id": 54901, "id": 54901,
"logprob": -0.72753906, "logprob": -0.84765625,
"special": false, "special": false,
"text": "beach" "text": "beach"
}, },
{ {
"id": 1, "id": 1,
"logprob": -0.011009216, "logprob": -0.008666992,
"special": true, "special": true,
"text": "<eos>" "text": "<eos>"
} }


@ -19,25 +19,25 @@
"tokens": [ "tokens": [
{ {
"id": 284, "id": 284,
"logprob": -0.19421387, "logprob": -0.28955078,
"special": false, "special": false,
"text": " to" "text": " to"
}, },
{ {
"id": 3758, "id": 3758,
"logprob": -0.62597656, "logprob": -0.7739258,
"special": false, "special": false,
"text": " send" "text": " send"
}, },
{ {
"id": 1366, "id": 1366,
"logprob": -0.87060547, "logprob": -0.85253906,
"special": false, "special": false,
"text": " data" "text": " data"
}, },
{ {
"id": 625, "id": 625,
"logprob": -0.88427734, "logprob": -0.8984375,
"special": false, "special": false,
"text": " over" "text": " over"
}, },
@ -49,7 +49,7 @@
}, },
{ {
"id": 3127, "id": 3127,
"logprob": -1.9462891, "logprob": -1.9404297,
"special": false, "special": false,
"text": " network" "text": " network"
} }


@ -16,7 +16,7 @@
}, },
{ {
"id": 100, "id": 100,
"logprob": -0.38549805, "logprob": -0.38305664,
"text": "_" "text": "_"
}, },
{ {
@ -29,7 +29,7 @@
"tokens": [ "tokens": [
{ {
"id": 2284, "id": 2284,
"logprob": -0.31323242, "logprob": -0.296875,
"special": false, "special": false,
"text": "():" "text": "():"
}, },
@ -59,19 +59,19 @@
}, },
{ {
"id": 10914, "id": 10914,
"logprob": -0.7817383, "logprob": -0.7734375,
"special": false, "special": false,
"text": " World" "text": " World"
}, },
{ {
"id": 16013, "id": 16013,
"logprob": -0.6328125, "logprob": -0.61816406,
"special": false, "special": false,
"text": "!\")" "text": "!\")"
}, },
{ {
"id": 222, "id": 222,
"logprob": -0.0619812, "logprob": -0.054870605,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
@ -83,7 +83,7 @@
}, },
{ {
"id": 610, "id": 610,
"logprob": -0.4086914, "logprob": -0.4152832,
"special": false, "special": false,
"text": "def" "text": "def"
}, },
@ -113,7 +113,7 @@
}, },
{ {
"id": 444, "id": 444,
"logprob": -0.21826172, "logprob": -0.21618652,
"special": false, "special": false,
"text": "name" "text": "name"
}, },
@ -173,7 +173,7 @@
}, },
{ {
"id": 11571, "id": 11571,
"logprob": -0.10021973, "logprob": -0.08892822,
"special": false, "special": false,
"text": "!\"" "text": "!\""
}, },


@ -1,8 +1,8 @@
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,
"finish_reason": "length", "finish_reason": "eos_token",
"generated_tokens": 20, "generated_tokens": 2,
"prefill": [ "prefill": [
{ {
"id": 589, "id": 589,
@ -11,57 +11,57 @@
}, },
{ {
"id": 3226, "id": 3226,
"logprob": -8.5859375, "logprob": -8.9453125,
"text": " ge" "text": " ge"
}, },
{ {
"id": 21017, "id": 21017,
"logprob": -7.5859375, "logprob": -8.8515625,
"text": "ometric" "text": "ometric"
}, },
{ {
"id": 81, "id": 81,
"logprob": -0.2668457, "logprob": -0.21875,
"text": "_" "text": "_"
}, },
{ {
"id": 6009, "id": 6009,
"logprob": -1.6416016, "logprob": -1.2773438,
"text": "mean" "text": "mean"
}, },
{ {
"id": 26, "id": 26,
"logprob": -0.22705078, "logprob": -0.25195312,
"text": "(" "text": "("
}, },
{ {
"id": 62, "id": 62,
"logprob": -5.2304688, "logprob": -4.8203125,
"text": "L" "text": "L"
}, },
{ {
"id": 44, "id": 44,
"logprob": -3.0976562, "logprob": -3.7734375,
"text": ":" "text": ":"
}, },
{ {
"id": 1682, "id": 1682,
"logprob": -1.1044922, "logprob": -0.8310547,
"text": " List" "text": " List"
}, },
{ {
"id": 77, "id": 77,
"logprob": -0.14294434, "logprob": -0.22766113,
"text": "[" "text": "["
}, },
{ {
"id": 1808, "id": 1808,
"logprob": -0.32299805, "logprob": -0.46240234,
"text": "float" "text": "float"
}, },
{ {
"id": 10794, "id": 10794,
"logprob": -2.8164062, "logprob": -3.0234375,
"text": "]):" "text": "]):"
} }
], ],
@ -69,126 +69,18 @@
"tokens": [ "tokens": [
{ {
"id": 284, "id": 284,
"logprob": -0.1282959, "logprob": -0.04626465,
"special": false, "special": false,
"text": "\n " "text": "\n "
}, },
{ {
"id": 1524, "id": 0,
"logprob": -0.97998047, "logprob": null,
"special": false, "special": true,
"text": " \"\"\"" "text": "<|endoftext|>"
},
{
"id": 284,
"logprob": -0.7006836,
"special": false,
"text": "\n "
},
{
"id": 14883,
"logprob": -2.1933594,
"special": false,
"text": " Calculate"
},
{
"id": 322,
"logprob": -0.2697754,
"special": false,
"text": " the"
},
{
"id": 3226,
"logprob": -0.0836792,
"special": false,
"text": " ge"
},
{
"id": 21017,
"logprob": -0.018737793,
"special": false,
"text": "ometric"
},
{
"id": 5651,
"logprob": -0.028640747,
"special": false,
"text": " mean"
},
{
"id": 432,
"logprob": -0.29467773,
"special": false,
"text": " of"
},
{
"id": 312,
"logprob": -0.31518555,
"special": false,
"text": " a"
},
{
"id": 1149,
"logprob": -0.20605469,
"special": false,
"text": " list"
},
{
"id": 432,
"logprob": -0.23254395,
"special": false,
"text": " of"
},
{
"id": 7515,
"logprob": -0.4489746,
"special": false,
"text": " numbers"
},
{
"id": 32,
"logprob": -0.6044922,
"special": false,
"text": "."
},
{
"id": 446,
"logprob": -0.63964844,
"special": false,
"text": "\n\n "
},
{
"id": 499,
"logprob": -1.1953125,
"special": false,
"text": " :"
},
{
"id": 753,
"logprob": -0.03515625,
"special": false,
"text": "param"
},
{
"id": 498,
"logprob": -0.06311035,
"special": false,
"text": " L"
},
{
"id": 44,
"logprob": -0.003414154,
"special": false,
"text": ":"
},
{
"id": 1682,
"logprob": -1.3310547,
"special": false,
"text": " List"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a list of numbers.\n\n :param L: List" "generated_text": "\n "
} }


@ -1,8 +1,8 @@
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,
"finish_reason": "length", "finish_reason": "eos_token",
"generated_tokens": 20, "generated_tokens": 2,
"prefill": [ "prefill": [
{ {
"id": 589, "id": 589,
@ -11,57 +11,57 @@
}, },
{ {
"id": 3226, "id": 3226,
"logprob": -8.5859375, "logprob": -8.9453125,
"text": " ge" "text": " ge"
}, },
{ {
"id": 21017, "id": 21017,
"logprob": -7.5898438, "logprob": -8.859375,
"text": "ometric" "text": "ometric"
}, },
{ {
"id": 81, "id": 81,
"logprob": -0.26586914, "logprob": -0.21984863,
"text": "_" "text": "_"
}, },
{ {
"id": 6009, "id": 6009,
"logprob": -1.6347656, "logprob": -1.2861328,
"text": "mean" "text": "mean"
}, },
{ {
"id": 26, "id": 26,
"logprob": -0.22705078, "logprob": -0.25219727,
"text": "(" "text": "("
}, },
{ {
"id": 62, "id": 62,
"logprob": -5.2382812, "logprob": -4.8007812,
"text": "L" "text": "L"
}, },
{ {
"id": 44, "id": 44,
"logprob": -3.0996094, "logprob": -3.7949219,
"text": ":" "text": ":"
}, },
{ {
"id": 1682, "id": 1682,
"logprob": -1.1025391, "logprob": -0.8046875,
"text": " List" "text": " List"
}, },
{ {
"id": 77, "id": 77,
"logprob": -0.14294434, "logprob": -0.22424316,
"text": "[" "text": "["
}, },
{ {
"id": 1808, "id": 1808,
"logprob": -0.32226562, "logprob": -0.46191406,
"text": "float" "text": "float"
}, },
{ {
"id": 10794, "id": 10794,
"logprob": -2.8164062, "logprob": -3.0253906,
"text": "]):" "text": "]):"
} }
], ],
@ -74,121 +74,13 @@
"text": "\n " "text": "\n "
}, },
{ {
"id": 442, "id": 0,
"logprob": -1.3134766, "logprob": null,
"special": false, "special": true,
"text": " return" "text": "<|endoftext|>"
},
{
"id": 11665,
"logprob": -0.10021973,
"special": false,
"text": " reduce"
},
{
"id": 26,
"logprob": 0.0,
"special": false,
"text": "("
},
{
"id": 5962,
"logprob": 0.0,
"special": false,
"text": "lambda"
},
{
"id": 816,
"logprob": 0.0,
"special": false,
"text": " x"
},
{
"id": 30,
"logprob": 0.0,
"special": false,
"text": ","
},
{
"id": 533,
"logprob": 0.0,
"special": false,
"text": " y"
},
{
"id": 44,
"logprob": 0.0,
"special": false,
"text": ":"
},
{
"id": 816,
"logprob": 0.0,
"special": false,
"text": " x"
},
{
"id": 319,
"logprob": -0.42871094,
"special": false,
"text": " *"
},
{
"id": 533,
"logprob": 0.0,
"special": false,
"text": " y"
},
{
"id": 30,
"logprob": 0.0,
"special": false,
"text": ","
},
{
"id": 498,
"logprob": 0.0,
"special": false,
"text": " L"
},
{
"id": 27,
"logprob": 0.0,
"special": false,
"text": ")"
},
{
"id": 1115,
"logprob": 0.0,
"special": false,
"text": " **"
},
{
"id": 308,
"logprob": 0.0,
"special": false,
"text": " ("
},
{
"id": 35,
"logprob": 0.0,
"special": false,
"text": "1"
},
{
"id": 32,
"logprob": -0.31323242,
"special": false,
"text": "."
},
{
"id": 34,
"logprob": 0.0,
"special": false,
"text": "0"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\n return reduce(lambda x, y: x * y, L) ** (1.0" "generated_text": "\n "
} }


@ -2,8 +2,8 @@
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,
"finish_reason": "length", "finish_reason": "eos_token",
"generated_tokens": 10, "generated_tokens": 2,
"prefill": [ "prefill": [
{ {
"id": 589, "id": 589,
@ -12,57 +12,57 @@
}, },
{ {
"id": 3226, "id": 3226,
"logprob": -8.5859375, "logprob": -8.9453125,
"text": " ge" "text": " ge"
}, },
{ {
"id": 21017, "id": 21017,
"logprob": -7.5820312, "logprob": -8.8515625,
"text": "ometric" "text": "ometric"
}, },
{ {
"id": 81, "id": 81,
"logprob": -0.26708984, "logprob": -0.22033691,
"text": "_" "text": "_"
}, },
{ {
"id": 6009, "id": 6009,
"logprob": -1.6386719, "logprob": -1.2939453,
"text": "mean" "text": "mean"
}, },
{ {
"id": 26, "id": 26,
"logprob": -0.22717285, "logprob": -0.25268555,
"text": "(" "text": "("
}, },
{ {
"id": 62, "id": 62,
"logprob": -5.234375, "logprob": -4.796875,
"text": "L" "text": "L"
}, },
{ {
"id": 44, "id": 44,
"logprob": -3.1015625, "logprob": -3.796875,
"text": ":" "text": ":"
}, },
{ {
"id": 1682, "id": 1682,
"logprob": -1.1083984, "logprob": -0.8066406,
"text": " List" "text": " List"
}, },
{ {
"id": 77, "id": 77,
"logprob": -0.14294434, "logprob": -0.22644043,
"text": "[" "text": "["
}, },
{ {
"id": 1808, "id": 1808,
"logprob": -0.32592773, "logprob": -0.46166992,
"text": "float" "text": "float"
}, },
{ {
"id": 10794, "id": 10794,
"logprob": -2.8164062, "logprob": -3.0253906,
"text": "]):" "text": "]):"
} }
], ],
@ -70,74 +70,26 @@
"tokens": [ "tokens": [
{ {
"id": 284, "id": 284,
"logprob": -0.12817383, "logprob": -0.046844482,
"special": false, "special": false,
"text": "\n " "text": "\n "
}, },
{ {
"id": 1524, "id": 0,
"logprob": -0.9863281, "logprob": null,
"special": false, "special": true,
"text": " \"\"\"" "text": "<|endoftext|>"
},
{
"id": 284,
"logprob": -0.7011719,
"special": false,
"text": "\n "
},
{
"id": 14883,
"logprob": -2.2050781,
"special": false,
"text": " Calculate"
},
{
"id": 322,
"logprob": -0.2668457,
"special": false,
"text": " the"
},
{
"id": 3226,
"logprob": -0.08465576,
"special": false,
"text": " ge"
},
{
"id": 21017,
"logprob": -0.019012451,
"special": false,
"text": "ometric"
},
{
"id": 5651,
"logprob": -0.028625488,
"special": false,
"text": " mean"
},
{
"id": 432,
"logprob": -0.29418945,
"special": false,
"text": " of"
},
{
"id": 312,
"logprob": -0.3161621,
"special": false,
"text": " a"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a" "generated_text": "\n "
}, },
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,
"finish_reason": "length", "finish_reason": "eos_token",
"generated_tokens": 10, "generated_tokens": 2,
"prefill": [ "prefill": [
{ {
"id": 589, "id": 589,
@ -146,57 +98,57 @@
}, },
{ {
"id": 3226, "id": 3226,
"logprob": -8.5859375, "logprob": -8.9375,
"text": " ge" "text": " ge"
}, },
{ {
"id": 21017, "id": 21017,
"logprob": -7.59375, "logprob": -8.8515625,
"text": "ometric" "text": "ometric"
}, },
{ {
"id": 81, "id": 81,
"logprob": -0.26953125, "logprob": -0.21826172,
"text": "_" "text": "_"
}, },
{ {
"id": 6009, "id": 6009,
"logprob": -1.640625, "logprob": -1.2871094,
"text": "mean" "text": "mean"
}, },
{ {
"id": 26, "id": 26,
"logprob": -0.22705078, "logprob": -0.25390625,
"text": "(" "text": "("
}, },
{ {
"id": 62, "id": 62,
"logprob": -5.234375, "logprob": -4.8085938,
"text": "L" "text": "L"
}, },
{ {
"id": 44, "id": 44,
"logprob": -3.1132812, "logprob": -3.7890625,
"text": ":" "text": ":"
}, },
{ {
"id": 1682, "id": 1682,
"logprob": -1.1123047, "logprob": -0.8076172,
"text": " List" "text": " List"
}, },
{ {
"id": 77, "id": 77,
"logprob": -0.14294434, "logprob": -0.22302246,
"text": "[" "text": "["
}, },
{ {
"id": 1808, "id": 1808,
"logprob": -0.32299805, "logprob": -0.46435547,
"text": "float" "text": "float"
}, },
{ {
"id": 10794, "id": 10794,
"logprob": -2.8164062, "logprob": -3.0234375,
"text": "]):" "text": "]):"
} }
], ],
@ -204,74 +156,26 @@
"tokens": [ "tokens": [
{ {
"id": 284, "id": 284,
"logprob": -0.12854004, "logprob": -0.046722412,
"special": false, "special": false,
"text": "\n " "text": "\n "
}, },
{ {
"id": 1524, "id": 0,
"logprob": -0.9897461, "logprob": null,
"special": false, "special": true,
"text": " \"\"\"" "text": "<|endoftext|>"
},
{
"id": 284,
"logprob": -0.69970703,
"special": false,
"text": "\n "
},
{
"id": 14883,
"logprob": -2.2050781,
"special": false,
"text": " Calculate"
},
{
"id": 322,
"logprob": -0.2668457,
"special": false,
"text": " the"
},
{
"id": 3226,
"logprob": -0.08496094,
"special": false,
"text": " ge"
},
{
"id": 21017,
"logprob": -0.019012451,
"special": false,
"text": "ometric"
},
{
"id": 5651,
"logprob": -0.029037476,
"special": false,
"text": " mean"
},
{
"id": 432,
"logprob": -0.2939453,
"special": false,
"text": " of"
},
{
"id": 312,
"logprob": -0.31591797,
"special": false,
"text": " a"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a" "generated_text": "\n "
}, },
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,
"finish_reason": "length", "finish_reason": "eos_token",
"generated_tokens": 10, "generated_tokens": 2,
"prefill": [ "prefill": [
{ {
"id": 589, "id": 589,
@ -280,57 +184,57 @@
}, },
{ {
"id": 3226, "id": 3226,
"logprob": -8.5859375, "logprob": -8.9453125,
"text": " ge" "text": " ge"
}, },
{ {
"id": 21017, "id": 21017,
"logprob": -7.5859375, "logprob": -8.8515625,
"text": "ometric" "text": "ometric"
}, },
{ {
"id": 81, "id": 81,
"logprob": -0.26586914, "logprob": -0.21813965,
"text": "_" "text": "_"
}, },
{ {
"id": 6009, "id": 6009,
"logprob": -1.6347656, "logprob": -1.2744141,
"text": "mean" "text": "mean"
}, },
{ {
"id": 26, "id": 26,
"logprob": -0.22766113, "logprob": -0.2512207,
"text": "(" "text": "("
}, },
{ {
"id": 62, "id": 62,
"logprob": -5.2265625, "logprob": -4.8046875,
"text": "L" "text": "L"
}, },
{ {
"id": 44, "id": 44,
"logprob": -3.0976562, "logprob": -3.7851562,
"text": ":" "text": ":"
}, },
{ {
"id": 1682, "id": 1682,
"logprob": -1.1025391, "logprob": -0.81396484,
"text": " List" "text": " List"
}, },
{ {
"id": 77, "id": 77,
"logprob": -0.1427002, "logprob": -0.22570801,
"text": "[" "text": "["
}, },
{ {
"id": 1808, "id": 1808,
"logprob": -0.32592773, "logprob": -0.46044922,
"text": "float" "text": "float"
}, },
{ {
"id": 10794, "id": 10794,
"logprob": -2.8164062, "logprob": -3.0234375,
"text": "]):" "text": "]):"
} }
], ],
@ -338,74 +242,26 @@
"tokens": [ "tokens": [
{ {
"id": 284, "id": 284,
"logprob": -0.13012695, "logprob": -0.04650879,
"special": false, "special": false,
"text": "\n " "text": "\n "
}, },
{ {
"id": 1524, "id": 0,
"logprob": -0.98046875, "logprob": null,
"special": false, "special": true,
"text": " \"\"\"" "text": "<|endoftext|>"
},
{
"id": 284,
"logprob": -0.69921875,
"special": false,
"text": "\n "
},
{
"id": 14883,
"logprob": -2.1992188,
"special": false,
"text": " Calculate"
},
{
"id": 322,
"logprob": -0.2668457,
"special": false,
"text": " the"
},
{
"id": 3226,
"logprob": -0.083496094,
"special": false,
"text": " ge"
},
{
"id": 21017,
"logprob": -0.01902771,
"special": false,
"text": "ometric"
},
{
"id": 5651,
"logprob": -0.029006958,
"special": false,
"text": " mean"
},
{
"id": 432,
"logprob": -0.29248047,
"special": false,
"text": " of"
},
{
"id": 312,
"logprob": -0.3161621,
"special": false,
"text": " a"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a" "generated_text": "\n "
}, },
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,
"finish_reason": "length", "finish_reason": "eos_token",
"generated_tokens": 10, "generated_tokens": 2,
"prefill": [ "prefill": [
{ {
"id": 589, "id": 589,
@ -414,57 +270,57 @@
}, },
{ {
"id": 3226, "id": 3226,
"logprob": -8.5859375, "logprob": -8.9453125,
"text": " ge" "text": " ge"
}, },
{ {
"id": 21017, "id": 21017,
"logprob": -7.5859375, "logprob": -8.8515625,
"text": "ometric" "text": "ometric"
}, },
{ {
"id": 81, "id": 81,
"logprob": -0.26904297, "logprob": -0.21960449,
"text": "_" "text": "_"
}, },
{ {
"id": 6009, "id": 6009,
"logprob": -1.6386719, "logprob": -1.2890625,
"text": "mean" "text": "mean"
}, },
{ {
"id": 26, "id": 26,
"logprob": -0.22705078, "logprob": -0.25073242,
"text": "(" "text": "("
}, },
{ {
"id": 62, "id": 62,
"logprob": -5.234375, "logprob": -4.8085938,
"text": "L" "text": "L"
}, },
{ {
"id": 44, "id": 44,
"logprob": -3.1132812, "logprob": -3.8046875,
"text": ":" "text": ":"
}, },
{ {
"id": 1682, "id": 1682,
"logprob": -1.1074219, "logprob": -0.8071289,
"text": " List" "text": " List"
}, },
{ {
"id": 77, "id": 77,
"logprob": -0.14477539, "logprob": -0.22570801,
"text": "[" "text": "["
}, },
{ {
"id": 1808, "id": 1808,
"logprob": -0.3256836, "logprob": -0.46118164,
"text": "float" "text": "float"
}, },
{ {
"id": 10794, "id": 10794,
"logprob": -2.8027344, "logprob": -3.0097656,
"text": "]):" "text": "]):"
} }
], ],
@ -472,67 +328,19 @@
"tokens": [ "tokens": [
{ {
"id": 284, "id": 284,
"logprob": -0.12915039, "logprob": -0.046539307,
"special": false, "special": false,
"text": "\n " "text": "\n "
}, },
{ {
"id": 1524, "id": 0,
"logprob": -0.98535156, "logprob": null,
"special": false, "special": true,
"text": " \"\"\"" "text": "<|endoftext|>"
},
{
"id": 284,
"logprob": -0.69921875,
"special": false,
"text": "\n "
},
{
"id": 14883,
"logprob": -2.2011719,
"special": false,
"text": " Calculate"
},
{
"id": 322,
"logprob": -0.26708984,
"special": false,
"text": " the"
},
{
"id": 3226,
"logprob": -0.08502197,
"special": false,
"text": " ge"
},
{
"id": 21017,
"logprob": -0.019012451,
"special": false,
"text": "ometric"
},
{
"id": 5651,
"logprob": -0.028625488,
"special": false,
"text": " mean"
},
{
"id": 432,
"logprob": -0.29589844,
"special": false,
"text": " of"
},
{
"id": 312,
"logprob": -0.31591797,
"special": false,
"text": " a"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a" "generated_text": "\n "
} }
] ]


@ -30,19 +30,19 @@
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.37573242, "logprob": -0.38061523,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 633, "id": 633,
"logprob": -0.09161377, "logprob": -0.09301758,
"special": false, "special": false,
"text": " new" "text": " new"
}, },
{ {
"id": 4480, "id": 4480,
"logprob": -0.26171875, "logprob": -0.26782227,
"special": false, "special": false,
"text": " feature" "text": " feature"
}, },
@ -78,7 +78,7 @@
}, },
{ {
"id": 13, "id": 13,
"logprob": 0.0, "logprob": -0.10632324,
"special": false, "special": false,
"text": "\n" "text": "\n"
} }


@ -26,13 +26,13 @@
}, },
{ {
"id": 259, "id": 259,
"logprob": -0.4716797, "logprob": -0.46948242,
"special": false, "special": false,
"text": " " "text": " "
}, },
{ {
"id": 261, "id": 261,
"logprob": -0.044677734, "logprob": -0.15307617,
"special": false, "special": false,
"text": "," "text": ","
}, },
@ -56,7 +56,7 @@
}, },
{ {
"id": 35622, "id": 35622,
"logprob": -1.1630859, "logprob": -1.2998047,
"special": false, "special": false,
"text": " cloud" "text": " cloud"
}, },


@ -35,6 +35,6 @@ async def test_flash_llama_simple(flash_llama_chat, response_snapshot):
print(repr(response.choices[0].message.content)) print(repr(response.choices[0].message.content))
assert ( assert (
response.choices[0].message.content response.choices[0].message.content
== "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas" == "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to appreciate nature.\n\nIn terms of temperature, the warmest times of the year are from June to August, when average high temperatures typically range from around 73°F or 23°C"
) )
assert response == response_snapshot assert response == response_snapshot


@ -21,7 +21,6 @@ async def flash_llama_exl2(flash_llama_exl2_handle):
return flash_llama_exl2_handle.client return flash_llama_exl2_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot): async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
@ -33,7 +32,6 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh
assert response == ignore_logprob_response_snapshot assert response == ignore_logprob_response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_exl2_all_params( async def test_flash_llama_exl2_all_params(
@ -60,7 +58,6 @@ async def test_flash_llama_exl2_all_params(
assert response == ignore_logprob_response_snapshot assert response == ignore_logprob_response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_exl2_load( async def test_flash_llama_exl2_load(


@ -21,6 +21,7 @@ async def test_flash_llama_fp8(flash_llama_fp8, response_snapshot):
"Test request", max_new_tokens=10, decoder_input_details=True "Test request", max_new_tokens=10, decoder_input_details=True
) )
assert response.generated_text == " for the 2019-2020 school year"
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert response == response_snapshot assert response == response_snapshot
@ -57,6 +58,8 @@ async def test_flash_llama_fp8_load(flash_llama_fp8, generate_load, response_sna
) )
assert len(responses) == 4 assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses]) assert responses[0].generated_text == " for the 2019-2020 school year"
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"Different messages : {[r.generated_text for r in responses]}"
assert responses == response_snapshot assert responses == response_snapshot


@ -21,7 +21,7 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
max_new_tokens=20, max_new_tokens=20,
decoder_input_details=True, decoder_input_details=True,
) )
assert response.details.generated_tokens == 20 assert response.details.generated_tokens == 2
assert response == generous_response_snapshot assert response == generous_response_snapshot
@ -38,7 +38,7 @@ async def test_flash_starcoder_gptq_default_params(
decoder_input_details=True, decoder_input_details=True,
seed=0, seed=0,
) )
assert response.details.generated_tokens == 20 assert response.details.generated_tokens == 2
assert response == generous_response_snapshot assert response == generous_response_snapshot


@ -98,6 +98,6 @@ async def test_grammar_response_format_llama_error_if_tools_not_installed(
# 422 means the server was unable to process the request because it contains invalid data. # 422 means the server was unable to process the request because it contains invalid data.
assert response.status_code == 422 assert response.status_code == 422
assert response.json() == { assert response.json() == {
"error": "Grammar and tools are mutually exclusive", "error": "Tool error: Grammar and tools are mutually exclusive",
"error_type": "grammar and tools", "error_type": "tool_error",
} }


@ -62,6 +62,7 @@ async def test_mamba_load(
) )
assert len(responses) == 4 assert len(responses) == 4
assert responses[0].generated_text == "\n\nDeep learning is a new type of machine"
assert all([r.generated_text == responses[0].generated_text for r in responses]) assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses[0].generated_text == "\n\nDeep learning is a new type of machine" assert responses[0].generated_text == "\n\nDeep learning is a new type of machine"


@ -0,0 +1,19 @@
import pytest
@pytest.fixture(scope="module")
def opt_sharded_handle(launcher):
with launcher("facebook/opt-6.7b", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def opt_sharded(opt_sharded_handle):
await opt_sharded_handle.health(300)
return opt_sharded_handle.client
@pytest.mark.release
@pytest.mark.asyncio
async def test_opt(opt_sharded):
pass


@ -36,6 +36,7 @@ tools = [
}, },
}, },
"required": ["location", "format"], "required": ["location", "format"],
"additionalProperties": False,
}, },
}, },
}, },
@ -62,13 +63,13 @@ tools = [
}, },
}, },
"required": ["location", "format", "num_days"], "required": ["location", "format", "num_days"],
"additionalProperties": False,
}, },
}, },
}, },
] ]
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot): async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
@ -76,7 +77,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
max_tokens=100, max_tokens=100,
seed=1, seed=1,
tools=tools, tools=tools,
presence_penalty=-1.1, temperature=0.0,
messages=[ messages=[
{ {
"role": "system", "role": "system",
@ -91,19 +92,18 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
assert response.choices[0].message.content is None assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [ assert response.choices[0].message.tool_calls == [
{ {
"id": 0, "id": "0",
"type": "function", "type": "function",
"function": { "function": {
"description": None, "description": None,
"name": "get_current_weather", "name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"}, "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
}, },
} }
] ]
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_auto( async def test_flash_llama_grammar_tools_auto(
@ -113,8 +113,8 @@ async def test_flash_llama_grammar_tools_auto(
max_tokens=100, max_tokens=100,
seed=1, seed=1,
tools=tools, tools=tools,
temperature=0.0,
tool_choice="auto", tool_choice="auto",
presence_penalty=-1.1,
messages=[ messages=[
{ {
"role": "system", "role": "system",
@ -129,12 +129,12 @@ async def test_flash_llama_grammar_tools_auto(
assert response.choices[0].message.content is None assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [ assert response.choices[0].message.tool_calls == [
{ {
"id": 0, "id": "0",
"type": "function", "type": "function",
"function": { "function": {
"description": None, "description": None,
"name": "get_current_weather", "name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"}, "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
}, },
} }
] ]
@ -142,7 +142,6 @@ async def test_flash_llama_grammar_tools_auto(
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_choice( async def test_flash_llama_grammar_tools_choice(
@ -152,8 +151,8 @@ async def test_flash_llama_grammar_tools_choice(
max_tokens=100, max_tokens=100,
seed=1, seed=1,
tools=tools, tools=tools,
temperature=0.0,
tool_choice="get_current_weather", tool_choice="get_current_weather",
presence_penalty=-1.1,
messages=[ messages=[
{ {
"role": "system", "role": "system",
@ -168,12 +167,12 @@ async def test_flash_llama_grammar_tools_choice(
assert response.choices[0].message.content is None assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [ assert response.choices[0].message.tool_calls == [
{ {
"id": 0, "id": "0",
"type": "function", "type": "function",
"function": { "function": {
"description": None, "description": None,
"name": "get_current_weather", "name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"}, "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
}, },
} }
] ]
@ -181,7 +180,6 @@ async def test_flash_llama_grammar_tools_choice(
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_stream( async def test_flash_llama_grammar_tools_stream(
@ -191,8 +189,8 @@ async def test_flash_llama_grammar_tools_stream(
max_tokens=100, max_tokens=100,
seed=1, seed=1,
tools=tools, tools=tools,
temperature=0.0,
tool_choice="get_current_weather", tool_choice="get_current_weather",
presence_penalty=-1.1,
messages=[ messages=[
{ {
"role": "system", "role": "system",
@ -210,11 +208,10 @@ async def test_flash_llama_grammar_tools_stream(
async for response in responses: async for response in responses:
count += 1 count += 1
assert count == 38 assert count == 48
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information( async def test_flash_llama_grammar_tools_insufficient_information(
@ -222,13 +219,13 @@ async def test_flash_llama_grammar_tools_insufficient_information(
): ):
responses = await flash_llama_grammar_tools.chat( responses = await flash_llama_grammar_tools.chat(
max_tokens=100, max_tokens=100,
seed=8, seed=24,
tools=tools, tools=tools,
tool_choice="auto", tool_choice="auto",
messages=[ messages=[
{ {
"role": "system", "role": "system",
"content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION", "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
}, },
{ {
"role": "user", "role": "user",
@ -239,18 +236,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
) )
assert responses.choices[0].message.content is None assert responses.choices[0].message.content is None
assert responses.choices[0].message.tool_calls == [ assert (
{ responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
"function": { )
"arguments": {
"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
},
"description": None,
"name": "notify_error",
},
"id": 0,
"type": "function",
}
]
assert responses == response_snapshot assert responses == response_snapshot


@ -8,7 +8,7 @@ use nix::unistd::Pid;
use serde::Deserialize; use serde::Deserialize;
use std::env; use std::env;
use std::ffi::OsString; use std::ffi::OsString;
use std::io::{BufRead, BufReader, Lines}; use std::io::{BufRead, BufReader};
use std::os::unix::process::{CommandExt, ExitStatusExt}; use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path; use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio}; use std::process::{Child, Command, ExitStatus, Stdio};
@ -18,23 +18,134 @@ use std::sync::{mpsc, Arc};
use std::thread; use std::thread;
use std::thread::sleep; use std::thread::sleep;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::{fs, io}; use std::{
fs, io,
io::{Read, Write},
};
use thiserror::Error; use thiserror::Error;
use tracing_subscriber::{filter::LevelFilter, EnvFilter}; use tracing_subscriber::{filter::LevelFilter, EnvFilter};
mod env_runtime; mod env_runtime;
fn get_config(
model_id: &str,
revision: &Option<String>,
) -> Result<Config, Box<dyn std::error::Error>> {
let mut path = std::path::Path::new(model_id).to_path_buf();
let model_id = model_id.to_string();
let filename = if !path.exists() {
// Assume it's a hub id
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
// env variable has precedence over on file token.
ApiBuilder::new().with_token(Some(token)).build()?
} else {
Api::new()?
};
let repo = if let Some(ref revision) = revision {
api.repo(Repo::with_revision(
model_id,
RepoType::Model,
revision.to_string(),
))
} else {
api.model(model_id)
};
repo.get("config.json")?
} else {
path.push("config.json");
path
};
let content = std::fs::read_to_string(filename)?;
let config: RawConfig = serde_json::from_str(&content)?;
let config: Config = config.into();
Ok(config)
}
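
A minimal call-site sketch for get_config (illustrative only — the model id below is hypothetical, and the call needs either a local directory containing config.json or network access to the hub):
fn probe_config() -> Result<(), Box<dyn std::error::Error>> {
    // Local paths are tried first; otherwise config.json is fetched from the hub,
    // honouring HF_TOKEN and the optional revision.
    let config = get_config("bigscience/bloom-560m", &None)?;
    println!(
        "max_position_embeddings = {:?}, quantize = {:?}",
        config.max_position_embeddings, config.quantize
    );
    Ok(())
}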
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
if let Some(config) = config {
if prefix_caching.is_none() {
if config.vision_config.is_some() {
tracing::info!("Disabling prefix caching because of VLM model");
prefix_caching = Some("0".to_string());
} else if config.is_encoder_decoder {
tracing::info!("Disabling prefix caching because of seq2seq model");
prefix_caching = Some("0".to_string());
}
}
match config.head_dim {
Some(h) if h == 64 || h == 128 || h == 256 => {
if lora_adapters.is_some() && prefix_caching.is_none() {
tracing::info!("Disabling prefix caching because of lora adapters");
prefix_caching = Some("0".to_string());
}
match config.model_type.as_deref() {
Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
// Required because gemma2 needs bfloat16 which is not supported by
// flashinfer ?
if attention.is_none() {
tracing::info!(
"Forcing flash decoding because model {} requires it",
config.model_type.as_ref().unwrap()
);
attention = Some("flashdecoding".to_string());
}
}
Some("t5") => {}
_ => {}
}
}
_ => {
if attention.is_none() {
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
attention = Some("flashdecoding".to_string());
}
if prefix_caching.is_none() {
prefix_caching = Some("0".to_string());
}
}
}
}
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
let attention = attention.unwrap_or("flashinfer".to_string());
(prefix_caching, attention)
}
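
The precedence implemented above: explicit USE_PREFIX_CACHING / ATTENTION environment variables always win; otherwise the model config decides (vision or encoder-decoder models and unsupported head dims turn prefix caching off, and gemma2 / falcon / deepseek_v2 are steered to flash decoding). A unit-test-style sketch of that behaviour, using hypothetical config values rather than a real checkpoint:
#[test]
fn resolve_attention_gemma2_sketch() {
    // No explicit overrides from the environment.
    std::env::remove_var("USE_PREFIX_CACHING");
    std::env::remove_var("ATTENTION");
    let config = Some(Config {
        max_position_embeddings: Some(8192),
        quantize: None,
        head_dim: Some(256),
        model_type: Some("gemma2".to_string()),
        vision_config: None,
        is_encoder_decoder: false,
    });
    let (prefix_caching, attention) = resolve_attention(&config, &None);
    assert_eq!(prefix_caching, "true"); // supported head dim, no LoRA adapters
    assert_eq!(attention, "flashdecoding"); // gemma2 is forced off flashinfer
}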
#[derive(Deserialize)] #[derive(Deserialize)]
struct RawConfig { struct RawConfig {
max_position_embeddings: Option<usize>, max_position_embeddings: Option<usize>,
n_positions: Option<usize>, n_positions: Option<usize>,
model_type: Option<String>, model_type: Option<String>,
max_seq_len: Option<usize>, max_seq_len: Option<usize>,
quantization_config: Option<QuantizationConfig>,
n_embd: Option<usize>,
hidden_size: Option<usize>,
num_attention_heads: Option<usize>,
head_dim: Option<usize>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: Option<bool>,
} }
#[derive(Deserialize)]
struct QuantizationConfig {
quant_method: Option<Quantization>,
}
#[derive(Deserialize)]
struct VisionConfig {}
#[derive(Deserialize)] #[derive(Deserialize)]
struct Config { struct Config {
max_position_embeddings: Option<usize>, max_position_embeddings: Option<usize>,
quantize: Option<Quantization>,
head_dim: Option<usize>,
model_type: Option<String>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: bool,
} }
impl From<RawConfig> for Config { impl From<RawConfig> for Config {
@ -43,13 +154,39 @@ impl From<RawConfig> for Config {
.max_position_embeddings .max_position_embeddings
.or(other.max_seq_len) .or(other.max_seq_len)
.or(other.n_positions); .or(other.n_positions);
let quantize = other.quantization_config.and_then(|q| q.quant_method);
let head_dim = other.head_dim.or_else(|| {
match (other.hidden_size, other.n_embd, other.num_attention_heads) {
(Some(hidden_size), _, Some(num_attention_heads))
if hidden_size % num_attention_heads == 0 =>
{
Some(hidden_size / num_attention_heads)
}
// Legacy
(_, Some(hidden_size), Some(num_attention_heads))
if hidden_size % num_attention_heads == 0 =>
{
Some(hidden_size / num_attention_heads)
}
_ => None,
}
});
let model_type = other.model_type;
let vision_config = other.vision_config;
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
Config { Config {
max_position_embeddings, max_position_embeddings,
quantize,
head_dim,
model_type,
vision_config,
is_encoder_decoder,
} }
} }
} }
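
The head_dim fallback above is what lets resolve_attention work with configs that only publish hidden_size and num_attention_heads. A unit-test-style sketch with hypothetical llama-style numbers (hidden_size 4096 and 32 heads give head_dim = 4096 / 32 = 128):
#[test]
fn head_dim_fallback_sketch() {
    let raw = RawConfig {
        max_position_embeddings: Some(4096),
        n_positions: None,
        model_type: Some("llama".to_string()),
        max_seq_len: None,
        quantization_config: None,
        n_embd: None,
        hidden_size: Some(4096),
        num_attention_heads: Some(32),
        head_dim: None,
        vision_config: None,
        is_encoder_decoder: None,
    };
    let config: Config = raw.into();
    assert_eq!(config.head_dim, Some(128));
    assert!(!config.is_encoder_decoder); // a missing flag defaults to false
}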
#[derive(Clone, Copy, Debug, ValueEnum)] #[derive(Clone, Copy, Debug, ValueEnum, Deserialize)]
#[serde(rename_all = "kebab-case")]
enum Quantization { enum Quantization {
/// 4 bit quantization. Requires a specific AWQ quantized model: /// 4 bit quantization. Requires a specific AWQ quantized model:
/// <https://hf.co/models?search=awq>. /// <https://hf.co/models?search=awq>.
@ -72,17 +209,17 @@ enum Quantization {
Marlin, Marlin,
/// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
/// but it is known that the model will be much slower to run than the native f16. /// but it is known that the model will be much slower to run than the native f16.
#[deprecated( // #[deprecated(
since = "1.1.0", // since = "1.1.0",
note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases" // note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
)] // )]
Bitsandbytes, Bitsandbytes,
/// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, /// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x,
/// but it is known that the model will be much slower to run than the native f16. /// but it is known that the model will be much slower to run than the native f16.
BitsandbytesNF4, BitsandbytesNf4,
/// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
/// perplexity performance for your model /// perplexity performance for your model
BitsandbytesFP4, BitsandbytesFp4,
/// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above
/// This dtype has native ops and should be the fastest if available. /// This dtype has native ops and should be the fastest if available.
/// This is currently not the fastest because of local unpacking + padding to satisfy matrix /// This is currently not the fastest because of local unpacking + padding to satisfy matrix
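
Deriving Deserialize with rename_all = "kebab-case" (together with the NF4/FP4 -> Nf4/Fp4 renames) is what lets the quant_method strings found in a model's quantization_config map directly onto this enum. A small sketch of that mapping, with made-up JSON input following the kebab-case rename:
#[test]
fn quantization_deserialize_sketch() {
    let q: Quantization = serde_json::from_str("\"bitsandbytes-nf4\"").unwrap();
    assert!(matches!(q, Quantization::BitsandbytesNf4));
    let m: Quantization = serde_json::from_str("\"marlin\"").unwrap();
    assert!(matches!(m, Quantization::Marlin));
}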
@ -99,10 +236,10 @@ impl std::fmt::Display for Quantization {
Quantization::Bitsandbytes => { Quantization::Bitsandbytes => {
write!(f, "bitsandbytes") write!(f, "bitsandbytes")
} }
Quantization::BitsandbytesNF4 => { Quantization::BitsandbytesNf4 => {
write!(f, "bitsandbytes-nf4") write!(f, "bitsandbytes-nf4")
} }
Quantization::BitsandbytesFP4 => { Quantization::BitsandbytesFp4 => {
write!(f, "bitsandbytes-fp4") write!(f, "bitsandbytes-fp4")
} }
Quantization::Exl2 => { Quantization::Exl2 => {
@ -721,6 +858,7 @@ fn shard_manager(
.args(shard_args) .args(shard_args)
.env_clear() .env_clear()
.envs(envs) .envs(envs)
.stdin(Stdio::piped())
.stdout(Stdio::piped()) .stdout(Stdio::piped())
.stderr(Stdio::piped()) .stderr(Stdio::piped())
.process_group(0) .process_group(0)
@ -742,12 +880,13 @@ fn shard_manager(
}; };
// Redirect STDOUT to the console // Redirect STDOUT to the console
let mut pstdin = p.stdin.take().unwrap();
let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap()); let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap()); let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
//stdout tracing thread //stdout tracing thread
thread::spawn(move || { thread::spawn(move || {
log_lines(shard_stdout_reader.lines()); log_lines(shard_stdout_reader);
}); });
// We read stderr in another thread as it seems that lines() can block in some cases // We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel(); let (err_sender, err_receiver) = mpsc::channel();
@ -756,6 +895,18 @@ fn shard_manager(
err_sender.send(line).unwrap_or(()); err_sender.send(line).unwrap_or(());
} }
}); });
// We read stdin in another thread as it seems that lines() can block in some cases
thread::spawn(move || {
let mut stdin = io::stdin(); // We get `Stdin` here.
loop {
let mut buffer = vec![0; 4096];
if let Ok(n) = stdin.read(&mut buffer) {
if n > 0 {
let _ = pstdin.write_all(&buffer[..n]);
}
}
}
});
let mut ready = false; let mut ready = false;
let start_time = Instant::now(); let start_time = Instant::now();
@ -862,19 +1013,36 @@ impl PythonLogMessage {
} }
} }
impl TryFrom<&String> for PythonLogMessage { impl TryFrom<&[u8]> for PythonLogMessage {
type Error = serde_json::Error; type Error = serde_json::Error;
fn try_from(value: &String) -> Result<Self, Self::Error> { fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
serde_json::from_str::<Self>(value) serde_json::from_slice::<Self>(value)
} }
} }
fn log_lines<S: Sized + BufRead>(lines: Lines<S>) { fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
for line in lines.map_while(Result::ok) { let mut buffer = vec![0u8; 8 * 4096];
match PythonLogMessage::try_from(&line) { let mut stdout = std::io::stdout();
Ok(log) => log.trace(), loop {
Err(_) => tracing::debug!("{line}"), let n = bufread.read(&mut buffer);
if let Ok(n) = n {
if n > 0 {
let mut lines = buffer[..n].split(|i| *i == b'\n').peekable();
while let Some(line) = lines.next() {
match PythonLogMessage::try_from(line) {
Ok(log) => log.trace(),
// For interactive debugging ?
Err(_) => {
stdout.write_all(line).unwrap();
if lines.peek().is_some() {
stdout.write_all(b"\n").unwrap();
}
stdout.flush().unwrap();
}
}
}
}
} }
} }
} }
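
The rewritten log_lines reads raw bytes instead of lines() (which could block), splits each chunk on b'\n', and only falls back to plain stdout when a segment is not valid PythonLogMessage JSON; the peekable iterator avoids emitting a newline after a trailing, still-incomplete segment. A self-contained sketch of just that splitting pattern, on a made-up byte string:
fn split_sketch() {
    let buffer = b"{\"json\":\"record\"}\nplain progress line\npartial";
    let mut segments = buffer.split(|i| *i == b'\n').peekable();
    while let Some(segment) = segments.next() {
        // The real code tries PythonLogMessage::try_from(segment) before echoing.
        print!("{}", String::from_utf8_lossy(segment));
        if segments.peek().is_some() {
            println!(); // newline only between complete segments
        }
    }
}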
@ -1034,7 +1202,7 @@ fn download_convert_model(
let download_stdout = BufReader::new(download_process.stdout.take().unwrap()); let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
thread::spawn(move || { thread::spawn(move || {
log_lines(download_stdout.lines()); log_lines(download_stdout);
}); });
let download_stderr = BufReader::new(download_process.stderr.take().unwrap()); let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
@ -1085,6 +1253,7 @@ fn spawn_shards(
cuda_graphs: Vec<usize>, cuda_graphs: Vec<usize>,
max_total_tokens: usize, max_total_tokens: usize,
max_input_tokens: usize, max_input_tokens: usize,
quantize: Option<Quantization>,
max_log_level: LevelFilter, max_log_level: LevelFilter,
shutdown: Arc<AtomicBool>, shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>, shutdown_receiver: &mpsc::Receiver<()>,
@ -1106,7 +1275,6 @@ fn spawn_shards(
let shutdown_sender = shutdown_sender.clone(); let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = args.otlp_endpoint.clone(); let otlp_endpoint = args.otlp_endpoint.clone();
let otlp_service_name = args.otlp_service_name.clone(); let otlp_service_name = args.otlp_service_name.clone();
let quantize = args.quantize;
let speculate = args.speculate; let speculate = args.speculate;
let dtype = args.dtype; let dtype = args.dtype;
let trust_remote_code = args.trust_remote_code; let trust_remote_code = args.trust_remote_code;
@ -1429,45 +1597,12 @@ fn main() -> Result<(), LauncherError> {
tracing::info!("{:#?}", args); tracing::info!("{:#?}", args);
let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> { let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
let model_id = args.model_id.clone(); let quantize = config.as_ref().and_then(|c| c.quantize);
let mut path = std::path::Path::new(&args.model_id).to_path_buf(); // Quantization usually means you're even more RAM constrained.
let filename = if !path.exists() { let max_default = 4096;
// Assume it's a hub id
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
// env variable has precedence over on file token.
ApiBuilder::new().with_token(Some(token)).build()?
} else {
Api::new()?
};
let repo = if let Some(ref revision) = args.revision {
api.repo(Repo::with_revision(
model_id,
RepoType::Model,
revision.to_string(),
))
} else {
api.model(model_id)
};
repo.get("config.json")?
} else {
path.push("config.json");
path
};
let content = std::fs::read_to_string(filename)?;
let config: RawConfig = serde_json::from_str(&content)?;
if config.model_type == Some("gemma2".to_string()) {
tracing::info!("Forcing flash decoding because of softcap usage");
std::env::set_var("FLASH_DECODING", "1");
}
let config: Config = config.into();
// Quantization usually means you're even more RAM constrained.
let max_default = 4096;
let max_position_embeddings = if let Some(config) = &config {
if let Some(max_position_embeddings) = config.max_position_embeddings { if let Some(max_position_embeddings) = config.max_position_embeddings {
if max_position_embeddings > max_default { if max_position_embeddings > max_default {
let max = max_position_embeddings; let max = max_position_embeddings;
@ -1477,17 +1612,20 @@ fn main() -> Result<(), LauncherError> {
{ {
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
} }
Ok(max_default) max_default
} else { } else {
Ok(max_position_embeddings) max_position_embeddings
} }
} else { } else {
Err(Box::new(LauncherError::ArgumentValidation( max_default
"no max defined".to_string(),
)))
} }
} else {
max_default
}; };
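The branching above reduces to "take the model's max_position_embeddings, but never let the default exceed 4096"; a minimal sketch of that rule, built only from the values visible in the diff:

// Hypothetical helper mirroring the clamping logic above.
fn default_max_position_embeddings(from_config: Option<usize>) -> usize {
    let max_default = 4096;
    match from_config {
        // e.g. a 128k-context model still gets a 4096 default unless overridden
        Some(max) if max > max_default => max_default,
        Some(max) => max,
        None => max_default,
    }
}
// default_max_position_embeddings(Some(131072)) == 4096
// default_max_position_embeddings(Some(2048))   == 2048
// default_max_position_embeddings(None)         == 4096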
let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096); let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
std::env::set_var("ATTENTION", attention);
let max_input_tokens = { let max_input_tokens = {
match (args.max_input_tokens, args.max_input_length) { match (args.max_input_tokens, args.max_input_length) {
@ -1544,18 +1682,26 @@ fn main() -> Result<(), LauncherError> {
))); )));
} }
let cuda_graphs = match (&args.cuda_graphs, &args.quantize) { if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
}
let quantize = args.quantize.or(quantize);
let cuda_graphs = match (&args.cuda_graphs, &quantize) {
(Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(), (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
#[allow(deprecated)] #[allow(deprecated)]
( (
None, None,
Some( Some(
Quantization::Bitsandbytes Quantization::Bitsandbytes
| Quantization::BitsandbytesNF4 | Quantization::BitsandbytesNf4
| Quantization::BitsandbytesFP4, | Quantization::BitsandbytesFp4,
), ),
) => { ) => {
tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them"); tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
vec![]
}
(None, Some(Quantization::Exl2)) => {
tracing::warn!("Exl2 doesn't work with cuda graphs, deactivating them");
vec![] vec![]
} }
_ => { _ => {
@ -1672,6 +1818,7 @@ fn main() -> Result<(), LauncherError> {
cuda_graphs, cuda_graphs,
max_total_tokens, max_total_tokens,
max_input_tokens, max_input_tokens,
quantize,
max_log_level, max_log_level,
shutdown.clone(), shutdown.clone(),
&shutdown_receiver, &shutdown_receiver,

nix/crate-overrides.nix Normal file

@ -0,0 +1,84 @@
{ pkgs, nix-filter }:
let
filter = nix-filter.lib;
in
with pkgs;
defaultCrateOverrides
// {
aws-lc-rs = attrs: {
# aws-lc-rs does its own custom parsing of Cargo environment
# variables like DEP_.*_INCLUDE. However buildRustCrate does
# not use the version number, so the parsing fails.
postPatch = ''
substituteInPlace build.rs \
--replace-fail \
"assert!(!selected.is_empty()" \
"// assert!(!selected.is_empty()"
'';
};
rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = "-C target-feature=-crt-static"; };
grpc-metadata = attrs: {
src = filter {
root = ../backends/grpc-metadata;
include = with filter; [
isDirectory
(matchExt "rs")
];
};
};
text-generation-benchmark = attrs: {
src = filter {
root = ../benchmark;
include = with filter; [
isDirectory
(matchExt "rs")
];
};
};
text-generation-client = attrs: {
src = filter {
root = ../.;
include = with filter; [
isDirectory
(and (inDirectory "backends/client") (matchExt "rs"))
(and (inDirectory "proto") (matchExt "proto"))
];
};
postPatch = "cd backends/client";
buildInputs = [ protobuf ];
};
text-generation-launcher = attrs: {
src = filter {
root = ../launcher;
include = with filter; [
isDirectory
(matchExt "rs")
];
};
};
text-generation-router = attrs: {
src = filter {
root = ../router;
include = with filter; [
isDirectory
(matchExt "rs")
];
};
};
text-generation-router-v3 = attrs: {
# We need to do the src/source root dance so that the build
# has access to the protobuf file.
src = filter {
root = ../.;
include = with filter; [
isDirectory
(and (inDirectory "backends/v3") (matchExt "rs"))
(and (inDirectory "proto") (matchExt "proto"))
];
};
postPatch = "cd backends/v3";
buildInputs = [ protobuf ];
};
}

nix/server.nix Normal file

@ -0,0 +1,110 @@
{
nix-filter,
buildPythonPackage,
poetry-core,
mypy-protobuf,
awq-inference-engine,
causal-conv1d,
eetq,
einops,
exllamav2,
fbgemm-gpu,
flashinfer,
flash-attn,
flash-attn-layer-norm,
flash-attn-rotary,
grpc-interceptor,
grpcio-reflection,
grpcio-status,
grpcio-tools,
hf-transfer,
loguru,
mamba-ssm,
marlin-kernels,
opentelemetry-api,
opentelemetry-exporter-otlp,
opentelemetry-instrumentation-grpc,
opentelemetry-semantic-conventions,
peft,
safetensors,
tokenizers,
torch,
sentencepiece,
transformers,
typer,
vllm,
}:
let
filter = nix-filter.lib;
in
buildPythonPackage {
name = "text-generation-server";
src = filter {
root = ../.;
include = with filter; [
isDirectory
(and (inDirectory "server") (or_ (matchExt "py") (matchExt "pyi")))
"server/pyproject.toml"
(and (inDirectory "proto/v3") (matchExt "proto"))
];
};
pyproject = true;
build-system = [ poetry-core ];
nativeBuildInputs = [ mypy-protobuf ];
pythonRelaxDeps = [
"einops"
"huggingface-hub"
"loguru"
"opentelemetry-instrumentation-grpc"
"sentencepiece"
"typer"
];
pythonRemoveDeps = [ "scipy" ];
dependencies = [
awq-inference-engine
eetq
causal-conv1d
einops
exllamav2
fbgemm-gpu
flashinfer
flash-attn
flash-attn-layer-norm
flash-attn-rotary
grpc-interceptor
grpcio-reflection
grpcio-status
grpcio-tools
hf-transfer
loguru
mamba-ssm
marlin-kernels
opentelemetry-api
opentelemetry-exporter-otlp
opentelemetry-instrumentation-grpc
opentelemetry-semantic-conventions
peft
safetensors
sentencepiece
tokenizers
transformers
typer
vllm
];
prePatch = ''
python -m grpc_tools.protoc -Iproto/v3 --python_out=server/text_generation_server/pb \
--grpc_python_out=server/text_generation_server/pb --mypy_out=server/text_generation_server/pb proto/v3/generate.proto
find server/text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch server/text_generation_server/pb/__init__.py
cd server
'';
}


@ -3,22 +3,23 @@ syntax = "proto3";
package generate.v3; package generate.v3;
service TextGenerationService { service TextGenerationService {
/// Model Info /// Model Info
rpc Info (InfoRequest) returns (InfoResponse) {} rpc Info(InfoRequest) returns (InfoResponse) {}
/// Service discovery /// Service discovery
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {} rpc ServiceDiscovery(ServiceDiscoveryRequest)
/// Empties batch cache returns (ServiceDiscoveryResponse) {}
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse); /// Empties batch cache
/// Remove requests from a cached batch rpc ClearCache(ClearCacheRequest) returns (ClearCacheResponse);
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse); /// Remove requests from a cached batch
/// Warmup the model and compute max cache size rpc FilterBatch(FilterBatchRequest) returns (FilterBatchResponse);
rpc Warmup (WarmupRequest) returns (WarmupResponse); /// Warmup the model and compute max cache size
/// Prefill batch and decode first token rpc Warmup(WarmupRequest) returns (WarmupResponse);
rpc Prefill (PrefillRequest) returns (PrefillResponse); /// Prefill batch and decode first token
/// Decode token for a list of prefilled batches rpc Prefill(PrefillRequest) returns (PrefillResponse);
rpc Decode (DecodeRequest) returns (DecodeResponse); /// Decode token for a list of prefilled batches
/// Health check rpc Decode(DecodeRequest) returns (DecodeResponse);
rpc Health (HealthRequest) returns (HealthResponse); /// Health check
rpc Health(HealthRequest) returns (HealthResponse);
} }
message HealthRequest {} message HealthRequest {}
@ -28,240 +29,241 @@ message HealthResponse {}
message InfoRequest {} message InfoRequest {}
message InfoResponse { message InfoResponse {
bool requires_padding = 1; bool requires_padding = 1;
string dtype = 2; string dtype = 2;
string device_type = 3; string device_type = 3;
optional uint32 window_size = 4; optional uint32 window_size = 4;
uint32 speculate = 5; uint32 speculate = 5;
} }
/// Empty request /// Empty request
message ServiceDiscoveryRequest {} message ServiceDiscoveryRequest {}
message ServiceDiscoveryResponse { message ServiceDiscoveryResponse {
/// Other shards urls /// Other shards urls
repeated string urls = 1; repeated string urls = 1;
} }
message ClearCacheRequest { message ClearCacheRequest {
/// Optional batch id /// Optional batch id
optional uint64 id = 1; optional uint64 id = 1;
} }
/// Empty response /// Empty response
message ClearCacheResponse {} message ClearCacheResponse {}
message Image { message Image {
/// Binary image data. /// Binary image data.
bytes data = 1; bytes data = 1;
/// Image MIME type. /// Image MIME type.
string mimetype = 2; string mimetype = 2;
} }
message InputChunk { message InputChunk {
oneof chunk { oneof chunk {
/// Plain text data /// Plain text data
string text = 1; string text = 1;
/// Image data /// Image data
Image image = 2; Image image = 2;
} }
} }
message Input { message Input { repeated InputChunk chunks = 1; }
repeated InputChunk chunks = 1;
}
enum GrammarType { enum GrammarType {
GRAMMAR_TYPE_NONE = 0; GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1; GRAMMAR_TYPE_JSON = 1;
GRAMMAR_TYPE_REGEX = 2; GRAMMAR_TYPE_REGEX = 2;
} }
message NextTokenChooserParameters { message NextTokenChooserParameters {
/// exponential scaling output probability distribution /// exponential scaling output probability distribution
float temperature = 1; float temperature = 1;
/// restricting to the k highest probability elements /// restricting to the k highest probability elements
uint32 top_k = 2; uint32 top_k = 2;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float top_p = 3; float top_p = 3;
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
float typical_p = 4; float typical_p = 4;
/// apply sampling on the logits /// apply sampling on the logits
bool do_sample = 5; bool do_sample = 5;
/// random seed for sampling /// random seed for sampling
uint64 seed = 6; uint64 seed = 6;
/// repetition penalty /// repetition penalty
float repetition_penalty = 7; float repetition_penalty = 7;
/// frequency penalty /// frequency penalty
float frequency_penalty = 9; float frequency_penalty = 9;
/// token watermarking using "A Watermark for Large Language Models" /// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8; bool watermark = 8;
/// grammar (applied if not empty) /// grammar (applied if not empty)
string grammar = 10; string grammar = 10;
/// grammar type /// grammar type
GrammarType grammar_type = 11; GrammarType grammar_type = 11;
} }
message StoppingCriteriaParameters { message StoppingCriteriaParameters {
/// Maximum number of generated tokens /// Maximum number of generated tokens
uint32 max_new_tokens = 1; uint32 max_new_tokens = 1;
/// Optional stopping sequences /// Optional stopping sequences
repeated string stop_sequences = 2; repeated string stop_sequences = 2;
/// Ignore end of sequence token /// Ignore end of sequence token
/// used for benchmarking /// used for benchmarking
bool ignore_eos_token = 3; bool ignore_eos_token = 3;
} }
message Request { message Request {
/// Request ID /// Request ID
uint64 id = 1; uint64 id = 1;
/// The generation context as chunks /// The generation context as chunks
Input input_chunks = 8; Input input_chunks = 8;
/// The generation context, stringified input_chunks /// The generation context, stringified input_chunks
string inputs = 2; string inputs = 2;
/// Context truncation /// Context truncation
uint32 truncate = 3; uint32 truncate = 3;
/// Next Token Chooser Parameters /// Next Token Chooser Parameters
NextTokenChooserParameters parameters = 4; NextTokenChooserParameters parameters = 4;
/// Stopping Criteria Parameters /// Stopping Criteria Parameters
StoppingCriteriaParameters stopping_parameters = 5; StoppingCriteriaParameters stopping_parameters = 5;
/// Return prefill logprobs /// Return prefill logprobs
bool prefill_logprobs = 6; bool prefill_logprobs = 6;
/// Return most likely n tokens /// Return most likely n tokens
uint32 top_n_tokens = 7; uint32 top_n_tokens = 7;
/// Paged attention blocks /// Paged attention blocks
repeated uint32 blocks = 9; repeated uint32 blocks = 9;
/// Paged attention slots /// Paged attention slots
repeated uint32 slots = 10; repeated uint32 slots = 10;
/// LORA adapter index /// LORA adapter index
optional string adapter_id = 11; optional string adapter_id = 11;
/// Prefix length that can be retrieved from the KV cache.
uint32 prefix_len = 12;
/// Whether to add special tokens when tokenizing the input
bool add_special_tokens = 13;
} }
message Batch { message Batch {
/// Batch ID /// Batch ID
uint64 id = 1; uint64 id = 1;
/// Individual requests /// Individual requests
repeated Request requests = 2; repeated Request requests = 2;
/// Batch size (==len(requests)) /// Batch size (==len(requests))
uint32 size = 3; uint32 size = 3;
/// Maximum number of tokens this batch will grow to /// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4; uint32 max_tokens = 4;
/// Maximum number of Paged Attention blocks /// Maximum number of Paged Attention blocks
uint32 max_blocks = 5; uint32 max_blocks = 5;
} }
message CachedBatch { message CachedBatch {
/// Batch ID /// Batch ID
uint64 id = 1; uint64 id = 1;
/// Individual requests ids /// Individual requests ids
repeated uint64 request_ids = 2; repeated uint64 request_ids = 2;
/// Batch size (==len(requests)) /// Batch size (==len(requests))
uint32 size = 3; uint32 size = 3;
/// Maximum number of tokens this batch will grow to /// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4; uint32 max_tokens = 4;
} }
enum FinishReason { enum FinishReason {
FINISH_REASON_LENGTH = 0; FINISH_REASON_LENGTH = 0;
FINISH_REASON_EOS_TOKEN = 1; FINISH_REASON_EOS_TOKEN = 1;
FINISH_REASON_STOP_SEQUENCE = 2; FINISH_REASON_STOP_SEQUENCE = 2;
} }
message GeneratedText { message GeneratedText {
/// Output /// Output
string text = 1; string text = 1;
/// Number of generated tokens /// Number of generated tokens
uint32 generated_tokens = 2; uint32 generated_tokens = 2;
/// Finish reason /// Finish reason
FinishReason finish_reason = 3; FinishReason finish_reason = 3;
/// Seed /// Seed
optional uint64 seed = 4; optional uint64 seed = 4;
} }
message Tokens { message Tokens {
/// Token IDs /// Token IDs
repeated uint32 ids = 1; repeated uint32 ids = 1;
/// Logprobs /// Logprobs
repeated float logprobs = 2; repeated float logprobs = 2;
/// tokens /// tokens
repeated string texts = 3; repeated string texts = 3;
/// special /// special
repeated bool is_special = 4; repeated bool is_special = 4;
} }
message Generation { message Generation {
/// Request ID /// Request ID
uint64 request_id = 1; uint64 request_id = 1;
/// Prefill tokens (optional) /// Prefill tokens (optional)
Tokens prefill_tokens = 2; Tokens prefill_tokens = 2;
Tokens tokens = 3; Tokens tokens = 3;
/// Complete generated text /// Complete generated text
optional GeneratedText generated_text = 4; optional GeneratedText generated_text = 4;
/// Top tokens /// Top tokens
repeated Tokens top_tokens = 5; repeated Tokens top_tokens = 5;
} }
message FilterBatchRequest { message FilterBatchRequest {
/// Batch ID /// Batch ID
uint64 batch_id = 1; uint64 batch_id = 1;
/// Requests to keep /// Requests to keep
repeated uint64 request_ids = 2; repeated uint64 request_ids = 2;
} }
message FilterBatchResponse { message FilterBatchResponse {
/// Filtered Batch (cached) /// Filtered Batch (cached)
CachedBatch batch = 1; CachedBatch batch = 1;
} }
message PrefillRequest { message PrefillRequest {
/// Batch /// Batch
Batch batch = 1; Batch batch = 1;
} }
message PrefillResponse { message PrefillResponse {
/// Generation /// Generation
repeated Generation generations = 1; repeated Generation generations = 1;
/// Next batch (cached) /// Next batch (cached)
optional CachedBatch batch = 2; optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds /// Forward elapsed time in nanoseconds
uint64 forward_ns = 3; uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds /// Decode elapsed time in nanoseconds
uint64 decode_ns = 4; uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds /// Total elapsed time in nanoseconds
uint64 total_ns = 5; uint64 total_ns = 5;
} }
message DecodeRequest { message DecodeRequest {
/// Cached batches /// Cached batches
repeated CachedBatch batches = 1; repeated CachedBatch batches = 1;
} }
message DecodeResponse { message DecodeResponse {
/// Decodes /// Decodes
repeated Generation generations = 1; repeated Generation generations = 1;
/// Next batch (cached) /// Next batch (cached)
optional CachedBatch batch = 2; optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds /// Forward elapsed time in nanoseconds
uint64 forward_ns = 3; uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds /// Decode elapsed time in nanoseconds
uint64 decode_ns = 4; uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds /// Total elapsed time in nanoseconds
uint64 total_ns = 5; uint64 total_ns = 5;
/// Concatenate elapsed time in nanoseconds /// Concatenate elapsed time in nanoseconds
optional uint64 concat_ns = 6; optional uint64 concat_ns = 6;
} }
message WarmupRequest { message WarmupRequest {
/// Batch to warmup on /// Batch to warmup on
Batch batch = 1; Batch batch = 1;
uint32 max_input_length = 2; uint32 max_input_length = 2;
uint32 max_prefill_tokens = 3; uint32 max_prefill_tokens = 3;
uint32 max_total_tokens = 4; uint32 max_total_tokens = 4;
} }
message WarmupResponse { message WarmupResponse {
/// Maximum number of tokens supported by the model /// Maximum number of tokens supported by the model
optional uint32 max_supported_total_tokens = 1; optional uint32 max_supported_total_tokens = 1;
} }
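The new `prefix_len` and `add_special_tokens` fields on `Request` above surface in the prost-generated Rust types; a minimal sketch, assuming the usual tonic/prost codegen for this proto (the module path is an assumption, and unset fields take their prost defaults):

// Assumes something like
// `pub mod generate { pub mod v3 { tonic::include_proto!("generate.v3"); } }`
// is present in the client crate.
fn example_request() -> generate::v3::Request {
    generate::v3::Request {
        id: 0,
        inputs: "Hello".to_string(),
        add_special_tokens: true, // false for inputs that are already chat-templated
        prefix_len: 0,            // tokens already available from the KV cache
        ..Default::default()
    }
}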


@ -27,8 +27,14 @@ reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188" serde = "1.0.188"
serde_json = "1.0.107" serde_json = "1.0.107"
thiserror = "1.0.48" thiserror = "1.0.48"
tokenizers = { workspace = true} tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.32.0", features = [
"rt",
"rt-multi-thread",
"parking_lot",
"signal",
"sync",
] }
tokio-stream = "0.1.14" tokio-stream = "0.1.14"
tower-http = { version = "0.5.1", features = ["cors"] } tower-http = { version = "0.5.1", features = ["cors"] }
tracing = "0.1.40" tracing = "0.1.40"
@ -37,16 +43,22 @@ tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true } ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } init-tracing-opentelemetry = { version = "0.14.1", features = [
minijinja = { version = "2.0.2" } "opentelemetry-otlp",
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } ] }
minijinja = { workspace = true }
minijinja-contrib = { workspace = true }
futures-util = "0.3.30" futures-util = "0.3.30"
regex = "1.10.3" regex = "1.10.3"
once_cell = "1.19.0" once_cell = "1.19.0"
image = "0.25.1" image = "0.25.1"
base64 = { workspace = true } base64 = { workspace = true }
sysinfo = "0.30.13" sysinfo = "0.30.13"
uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] } uuid = { version = "1.9.1", default-features = false, features = [
"v4",
"fast-rng",
"macro-diagnostics",
] }
csv = "1.3.0" csv = "1.3.0"
ureq = "=2.9" ureq = "=2.9"


@ -153,6 +153,7 @@ pub enum Config {
Bloom, Bloom,
Mpt, Mpt,
Gpt2, Gpt2,
Gptj,
GptNeox, GptNeox,
Phi, Phi,
#[serde(rename = "phi-msft")] #[serde(rename = "phi-msft")]


@ -1,9 +1,8 @@
use crate::infer::InferError; use crate::infer::InferError;
use crate::{ use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
};
use minijinja::{Environment, ErrorKind, Template}; use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat; use minijinja_contrib::pycompat;
use std::collections::HashSet;
/// Raise an exception (custom function) used in the chat templates /// Raise an exception (custom function) used in the chat templates
pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> { pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
@ -16,6 +15,7 @@ pub(crate) struct ChatTemplate {
bos_token: Option<String>, bos_token: Option<String>,
eos_token: Option<String>, eos_token: Option<String>,
use_default_tool_template: bool, use_default_tool_template: bool,
variables: HashSet<String>,
} }
impl ChatTemplate { impl ChatTemplate {
@ -29,48 +29,70 @@ impl ChatTemplate {
env.set_unknown_method_callback(pycompat::unknown_method_callback); env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str(); let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception); env.add_function("raise_exception", raise_exception);
tracing::debug!("Loading template: {:#?}", template_str);
// check if contains the tools variable within the template
let use_default_tool_template =
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
// leaking env and template_str as read-only, static resources for performance. // leaking env and template_str as read-only, static resources for performance.
let template = Box::leak(env) let template = Box::leak(env)
.template_from_str(Box::leak(template_str)) .template_from_str(Box::leak(template_str))
.unwrap(); .unwrap();
// get the list of variables that are used in the template
let variables = template.undeclared_variables(true);
// check if the `tools` variable is used in the template
let use_default_tool_template = !variables.contains("tools");
tracing::debug!("Use default tool template: {}", use_default_tool_template);
Self { Self {
template, template,
bos_token: bos_token.map(|token| token.as_str().to_string()), bos_token: bos_token.map(|token| token.as_str().to_string()),
eos_token: eos_token.map(|token| token.as_str().to_string()), eos_token: eos_token.map(|token| token.as_str().to_string()),
use_default_tool_template, use_default_tool_template,
variables,
} }
} }
pub(crate) fn apply( pub(crate) fn apply(
&self, &self,
guideline: Option<&str>,
mut messages: Vec<Message>, mut messages: Vec<Message>,
grammar_with_prompt: Option<(GrammarType, String)>, tools_and_prompt: Option<(Vec<Tool>, String)>,
) -> Result<String, InferError> { ) -> Result<String, InferError> {
if self.use_default_tool_template { // check if guideline is expected but not provided
if let Some(last_message) = messages.last_mut() { if self.variables.contains("guideline") && guideline.is_none() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { return Err(InferError::MissingTemplateVariable("guideline".to_string()));
last_message.content.push(MessageChunk::Text {
text: format!("\n---\n{}\n{}", tool_prompt, tools),
});
}
}
} }
let tools = match tools_and_prompt {
Some((tools, tool_prompt)) => {
// check if the `tools` variable is used in the template
// if not, we need to append the tools to the last message
let text = if self.use_default_tool_template {
match serde_json::to_string(&tools) {
Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
Err(e) => return Err(InferError::ToolError(e.to_string())),
}
} else {
// if the `tools` variable is used in the template, we just append the tool_prompt
format!("\n---\n{}", tool_prompt)
};
if let Some(last_message) = messages.last_mut() {
last_message.content.push(MessageChunk::Text { text });
}
Some(tools)
}
None => None,
};
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect(); let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
self.template self.template
.render(ChatTemplateInputs { .render(ChatTemplateInputs {
guideline,
messages, messages,
bos_token: self.bos_token.as_deref(), bos_token: self.bos_token.as_deref(),
eos_token: self.eos_token.as_deref(), eos_token: self.eos_token.as_deref(),
add_generation_prompt: true, add_generation_prompt: true,
tools: None, tools,
tools_prompt: None,
}) })
.map_err(InferError::TemplateError) .map_err(InferError::TemplateError)
} }
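The `undeclared_variables(true)` call used in the constructor above is what drives both the `tools` fallback and the `guideline` check; a minimal standalone sketch (template string and variable names are illustrative):

use minijinja::Environment;

fn main() {
    let mut env = Environment::new();
    let template = env
        .template_from_str("{{ bos_token }}{% for m in messages %}{{ m.content }}{% endfor %}{{ guideline }}")
        .unwrap();

    // Collect every variable the template reads, including nested lookups.
    let variables = template.undeclared_variables(true);

    // Mirrors the checks above: no `tools` -> fall back to the default tool
    // template; `guideline` present -> the caller must provide one.
    let use_default_tool_template = !variables.contains("tools");
    let needs_guideline = variables.contains("guideline");
    println!("{use_default_tool_template} {needs_guideline}");
}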
@ -80,7 +102,10 @@ impl ChatTemplate {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::infer::chat_template::raise_exception; use crate::infer::chat_template::raise_exception;
use crate::{ChatTemplateInputs, TextMessage}; use crate::infer::ChatTemplate;
use crate::{
ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool,
};
use minijinja::Environment; use minijinja::Environment;
#[test] #[test]
@ -731,6 +756,19 @@ mod tests {
}, },
target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!", target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!",
}, },
ChatTemplateTestItem {
name: "google/shieldgemma-9b",
chat_template: "{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n<end_of_turn>\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n<end_of_turn>\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n<end_of_turn>\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n",
input: ChatTemplateInputs {
messages: example_chat_with_system.clone(),
add_generation_prompt: false,
bos_token: Some("<s>"),
eos_token: Some("</s>"),
guideline: Some("Do not use offensive language."),
..Default::default()
},
target: "<s>You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\n\n<start_of_turn>\nHuman Question: I'd like to show off how chat templating works!\n<end_of_turn>\n\nOur safety principle is defined in the below:\n\n* Do not use offensive language.\n\n===\n\nDoes the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\n\n",
},
]; ];
#[allow(unused_variables)] // name is unused #[allow(unused_variables)] // name is unused
@ -755,4 +793,116 @@ mod tests {
assert_eq!(result, target); assert_eq!(result, target);
} }
} }
#[test]
fn test_chat_template_invalid_with_guideline() {
let ct = ChatTemplate::new(
"{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n<end_of_turn>\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n<end_of_turn>\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n<end_of_turn>\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n".to_string(),
Some(TokenizerConfigToken::String("<s>".to_string())),
Some(TokenizerConfigToken::String("</s>".to_string())),
);
// convert TextMessage to Message
let msgs: Vec<Message> = vec![
Message {
name: None,
role: "user".to_string(),
content: MessageContent::SingleText(
"I'd like to show off how chat templating works!".to_string(),
),
},
Message {
name: None,
role: "assistant".to_string(),
content: MessageContent::SingleText(
"I'm doing great. How can I help you today?".to_string(),
),
},
Message {
name: None,
role: "user".to_string(),
content: MessageContent::SingleText("Hello, how are you?".to_string()),
},
];
let result = ct.apply(None, msgs, None);
match result {
Ok(_) => panic!("Should have failed since no guideline is provided"),
Err(e) => {
assert_eq!(e.to_string(), "Missing template variable: guideline")
}
}
}
#[test]
fn test_chat_template_with_default_tool_template() {
let ct = ChatTemplate::new(
"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}".to_string(),
Some(TokenizerConfigToken::String("<s>".to_string())),
Some(TokenizerConfigToken::String("</s>".to_string())),
);
// convert TextMessage to Message
let msgs: Vec<Message> = vec![
Message {
name: None,
role: "user".to_string(),
content: MessageContent::SingleText(
"I'd like to show off how chat templating works!".to_string(),
),
},
Message {
name: None,
role: "assistant".to_string(),
content: MessageContent::SingleText("Great! How can I help you today?".to_string()),
},
Message {
name: None,
role: "user".to_string(),
content: MessageContent::SingleText("Just testing".to_string()),
},
];
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(None, msgs, tools_and_prompt);
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
assert_eq!(result.unwrap(), expected);
}
#[test]
fn test_chat_template_with_custom_tool_template() {
// chat template from meta-llama/Meta-Llama-3.1-8B-Instruct
let ct = ChatTemplate::new(
"{{- bos_token }}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n".to_string(),
Some(TokenizerConfigToken::String("<s>".to_string())),
Some(TokenizerConfigToken::String("</s>".to_string())),
);
let msgs: Vec<Message> = vec![
Message {
name: None,
role: "system".to_string(),
content: MessageContent::SingleText(
"Youre a helpful assistant! Answer the users question best you can."
.to_string(),
),
},
Message {
name: None,
role: "user".to_string(),
content: MessageContent::SingleText(
"What is the weather like in Brooklyn, New York?".to_string(),
),
},
];
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(None, msgs, tools_and_prompt);
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
assert_eq!(result.unwrap(), expected);
}
} }


@ -3,7 +3,7 @@ mod chat_template;
pub mod tool_grammar; pub mod tool_grammar;
use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
use crate::GrammarType; use crate::Tool;
use crate::{ use crate::{
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
Message, PrefillToken, Token, Message, PrefillToken, Token,
@ -120,10 +120,11 @@ impl Infer {
) -> Result<Option<tokenizers::Encoding>, InferError> { ) -> Result<Option<tokenizers::Encoding>, InferError> {
// Tokenize request // Tokenize request
let inputs = request.inputs; let inputs = request.inputs;
let add_special_tokens = request.add_special_tokens;
let truncate = request.parameters.truncate; let truncate = request.parameters.truncate;
let encoding = self let encoding = self
.validation .validation
.tokenize(inputs, truncate) .tokenize(inputs, add_special_tokens, truncate)
.await .await
.map_err(|err| { .map_err(|err| {
tracing::error!("Tokenization {err}"); tracing::error!("Tokenization {err}");
@ -138,13 +139,14 @@ impl Infer {
#[instrument(skip_all)] #[instrument(skip_all)]
pub(crate) fn apply_chat_template( pub(crate) fn apply_chat_template(
&self, &self,
guideline: Option<String>,
messages: Vec<Message>, messages: Vec<Message>,
grammar_with_prompt: Option<(GrammarType, String)>, tools_and_prompt: Option<(Vec<Tool>, String)>,
) -> Result<String, InferError> { ) -> Result<String, InferError> {
self.chat_template self.chat_template
.as_ref() .as_ref()
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.apply(messages, grammar_with_prompt) .apply(guideline.as_deref(), messages, tools_and_prompt)
.map_err(|e| { .map_err(|e| {
metrics::counter!("tgi_request_failure", "err" => "template").increment(1); metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
tracing::error!("{e}"); tracing::error!("{e}");
@ -336,6 +338,8 @@ pub enum InferError {
IncompleteGeneration, IncompleteGeneration,
#[error("Template error: {0}")] #[error("Template error: {0}")]
TemplateError(#[from] minijinja::Error), TemplateError(#[from] minijinja::Error),
#[error("Missing template variable: {0}")]
MissingTemplateVariable(String),
#[error("Tool error: {0}")] #[error("Tool error: {0}")]
ToolError(String), ToolError(String),
} }
@ -348,6 +352,7 @@ impl InferError {
InferError::ValidationError(_) => "validation", InferError::ValidationError(_) => "validation",
InferError::IncompleteGeneration => "incomplete_generation", InferError::IncompleteGeneration => "incomplete_generation",
InferError::TemplateError(_) => "template_error", InferError::TemplateError(_) => "template_error",
InferError::MissingTemplateVariable(_) => "missing_template_variable",
InferError::ToolError(_) => "tool_error", InferError::ToolError(_) => "tool_error",
} }
} }


@ -1,5 +1,8 @@
use crate::infer::InferError; use crate::infer::InferError;
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools}; use crate::{
FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
ToolType,
};
use serde_json::{json, Map, Value}; use serde_json::{json, Map, Value};
use std::collections::HashMap; use std::collections::HashMap;
@ -16,17 +19,38 @@ impl ToolGrammar {
} }
pub fn apply( pub fn apply(
tools: Option<Vec<Tool>>, tools: Vec<Tool>,
tool_choice: ToolChoice, tool_choice: ToolChoice,
) -> Result<Option<Tools>, InferError> { ) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
// if no tools are provided, we return None // if no tools are provided, we return None
let tools = match tools { if tools.is_empty() {
Some(tools) if !tools.is_empty() => tools, return Ok((tools, None));
_ => return Ok(None), }
};
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
let mut tools = tools.clone();
// add the notify_error function to the tools
let notify_error = Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
name: "notify_error".to_string(),
description: Some("Notify an error or issue".to_string()),
arguments: json!({
"type": "object",
"properties": {
"error": {
"type": "string",
"description": "The error or issue to notify"
}
},
"required": ["error"]
}),
},
};
tools.push(notify_error);
// if tools are provided and no tool_choice we default to the OneOf // if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice { let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => { ToolType::FunctionName(name) => {
@ -35,87 +59,57 @@ impl ToolGrammar {
ToolType::Function { function } => { ToolType::Function { function } => {
vec![Self::find_tool_by_name(&tools, &function.name)?] vec![Self::find_tool_by_name(&tools, &function.name)?]
} }
ToolType::OneOf => tools, ToolType::OneOf => tools.clone(),
ToolType::NoTool => return Ok(None), ToolType::NoTool => return Ok((tools, None)),
}; };
// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);
let functions: HashMap<String, serde_json::Value> = tools_to_use let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter() .iter()
.map(|tool| { .map(|tool| {
let func = tool.function.clone(); let func = tool.function.clone();
// Clone the existing parameters, which are expected to be a JSON object let mut params = Map::new();
let mut params = if let Value::Object(params) = &func.arguments {
params.clone()
} else {
Map::new()
};
// Insert the function's description at the top level, outside of properties
params.insert( params.insert(
"description".to_string(), "description".to_string(),
Value::String(func.description.clone().unwrap_or_default()), Value::String(func.description.unwrap_or_default()),
); );
// Ensure 'properties' exists and is an object let mut properties = Map::new();
let properties = params let mut required = vec![Value::String("_name".to_string())];
.entry("properties".to_string())
.or_insert_with(|| json!({}))
.as_object_mut()
.unwrap();
// Insert the constant for the function name inside 'properties'
properties.insert( properties.insert(
"_name".to_string(), "_name".to_string(),
json!({ json!({
"type": "string", "type": "string",
"const": func.name.clone(), "const": func.name.clone(),
// "description": "The name of the function"
}), }),
); );
// Check if 'required' exists, and it is an array. If not, create an empty array. if let Value::Object(args) = func.arguments {
let required = params if let Some(Value::Object(props)) = args.get("properties") {
.entry("required".to_string()) properties.extend(props.clone());
.or_insert_with(|| json!([])) }
.as_array_mut() if let Some(Value::Array(reqs)) = args.get("required") {
.unwrap(); required.extend(reqs.clone());
}
// Add 'name' to the 'required' array if it is not already present params.insert(
if !required.iter().any(|r| r == "_name") { "additionalProperties".to_string(),
required.push(json!("_name")); Value::Bool(
args.get("additionalProperties").and_then(|v| v.as_str())
== Some("true"),
),
);
} }
params.insert("properties".to_string(), Value::Object(properties));
params.insert("required".to_string(), Value::Array(required));
(func.name, Value::Object(params)) (func.name, Value::Object(params))
}) })
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect(); .collect();
let tools = Tools { let tool_schema = JsonSchemaTool {
functions_map: FunctionsMap { functions }, functions_map: FunctionsMap { functions },
properties: Properties { properties: Properties {
function: tools_to_use function: tools_to_use
@ -123,13 +117,10 @@ impl ToolGrammar {
.map(|tool| FunctionRef { .map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()), ref_path: format!("#/$functions/{}", tool.function.name.clone()),
}) })
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(), .collect(),
}, },
}; };
Ok(Some(tools)) Ok((tools, Some(tool_schema)))
} }
} }
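For reference, the `notify_error` tool appended above corresponds to the following JSON; a minimal sketch using `serde_json::json!`, built only from the fields visible in the diff:

use serde_json::json;

fn main() {
    // The extra tool injected so the model can report that no tool fits the request.
    let notify_error = json!({
        "type": "function",
        "function": {
            "name": "notify_error",
            "description": "Notify an error or issue",
            "arguments": {
                "type": "object",
                "properties": {
                    "error": {
                        "type": "string",
                        "description": "The error or issue to notify"
                    }
                },
                "required": ["error"]
            }
        }
    });
    println!("{}", serde_json::to_string_pretty(&notify_error).unwrap());
}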


@ -205,6 +205,13 @@ impl State {
} }
} }
if let Some(max_size) = max_size {
if max_size == 0 {
tracing::debug!("No capacity");
return None;
}
}
// Pad prefill_token_budget to be a multiple of block size // Pad prefill_token_budget to be a multiple of block size
let prefill_token_budget = let prefill_token_budget =
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size; ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
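The rounding on the last line pads the prefill budget up to the next multiple of the block size; a minimal sketch of the same integer arithmetic (values are illustrative):

// With block_size = 16: a budget of 100 tokens is padded up to 112,
// and a budget that is already aligned (e.g. 128) is left unchanged.
fn pad_to_block(prefill_token_budget: u32, block_size: u32) -> u32 {
    ((prefill_token_budget + block_size - 1) / block_size) * block_size
}
// pad_to_block(100, 16) == 112
// pad_to_block(128, 16) == 128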


@ -1,10 +1,10 @@
/// Batching and inference logic /// Batching and inference logic
use crate::infer::v2::queue::{Entry, Queue}; use crate::infer::v2::queue::{Entry, Queue};
use crate::infer::{ use crate::infer::{
Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Attention, Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
}; };
use crate::validation::ValidGenerateRequest; use crate::validation::ValidGenerateRequest;
use crate::{FinishReason, PrefillToken, Token}; use crate::{Attention, FinishReason, PrefillToken, Token};
use nohash_hasher::IntMap; use nohash_hasher::IntMap;
use std::sync::{ use std::sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, Ordering},
@ -40,12 +40,18 @@ impl BackendV2 {
generation_health: Arc<AtomicBool>, generation_health: Arc<AtomicBool>,
) -> Self { ) -> Self {
// Infer shared state // Infer shared state
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { let attention = if let Ok(attention) = std::env::var("ATTENTION") {
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") attention
.parse()
.expect(&format!("Invalid attention was specified: `{attention}`"))
} else { } else {
false Attention::Paged
};
let block_size = if attention == Attention::FlashDecoding {
256
} else {
16
}; };
let block_size = if flashdecoding { 256 } else { 16 };
let queue = Queue::new(requires_padding, block_size, window_size, speculate); let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let batching_task_notifier = Arc::new(Notify::new()); let batching_task_notifier = Arc::new(Notify::new());
@ -161,8 +167,8 @@ pub(crate) async fn batching_task(
}; };
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize); let max_size =
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
// Try to get a new batch // Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)


@ -15,6 +15,45 @@ use tracing::warn;
use utoipa::ToSchema; use utoipa::ToSchema;
use validation::Validation; use validation::Validation;
#[derive(PartialEq)]
pub enum Attention {
Paged,
FlashDecoding,
FlashInfer,
}
impl Attention {
pub fn block_size(&self) -> u32 {
match self {
Attention::FlashDecoding => 256,
Attention::FlashInfer => 1,
Attention::Paged => 16,
}
}
}
#[derive(Debug)]
pub struct ParseError;
impl std::fmt::Display for ParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Cannot parse attention value")
}
}
impl std::error::Error for ParseError {}
impl std::str::FromStr for Attention {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"paged" => Ok(Attention::Paged),
"flashdecoding" => Ok(Attention::FlashDecoding),
"flashinfer" => Ok(Attention::FlashInfer),
_ => Err(ParseError),
}
}
}
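A short usage sketch of the `Attention` enum and its `FromStr` impl above, mirroring how the backend reads the `ATTENTION` environment variable (assumes the types above are in scope):

// Falls back to paged attention when ATTENTION is unset, and panics on an
// unknown value, like BackendV2::new above.
fn attention_from_env() -> Attention {
    match std::env::var("ATTENTION") {
        Ok(value) => value
            .parse()
            .unwrap_or_else(|_| panic!("Invalid attention was specified: `{value}`")),
        Err(_) => Attention::Paged,
    }
}
// attention_from_env().block_size() -> 16 (paged), 256 (flashdecoding), 1 (flashinfer)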
#[derive(Clone, Deserialize, ToSchema)] #[derive(Clone, Deserialize, ToSchema)]
pub(crate) struct VertexInstance { pub(crate) struct VertexInstance {
#[schema(example = "What is Deep Learning?")] #[schema(example = "What is Deep Learning?")]
@ -619,7 +658,7 @@ impl ChatCompletion {
message, message,
logprobs: return_logprobs logprobs: return_logprobs
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))), .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
finish_reason: details.finish_reason.to_string(), finish_reason: details.finish_reason.format(true),
}], }],
usage: Usage { usage: Usage {
prompt_tokens: details.prefill.len() as u32, prompt_tokens: details.prefill.len() as u32,
@ -811,10 +850,10 @@ pub(crate) struct ChatRequest {
pub tools: Option<Vec<Tool>>, pub tools: Option<Vec<Tool>>,
/// A prompt to be appended before the tools /// A prompt to be appended before the tools
#[serde(default = "default_tool_prompt")] #[serde(default)]
#[schema( #[schema(
nullable = true, nullable = true,
example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"" example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
)] )]
pub tool_prompt: Option<String>, pub tool_prompt: Option<String>,
@ -829,12 +868,15 @@ pub(crate) struct ChatRequest {
#[serde(default)] #[serde(default)]
#[schema(nullable = true, default = "null", example = "null")] #[schema(nullable = true, default = "null", example = "null")]
pub response_format: Option<GrammarType>, pub response_format: Option<GrammarType>,
/// A guideline to be used in the chat_template
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub guideline: Option<String>,
} }
fn default_tool_prompt() -> Option<String> { pub fn default_tool_prompt() -> String {
Some( "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
"\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
)
} }
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)] #[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
@ -876,7 +918,7 @@ impl From<ToolTypeDeserializer> for ToolChoice {
} }
#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)] #[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
pub struct Tools { pub struct JsonSchemaTool {
#[serde(flatten)] #[serde(flatten)]
functions_map: FunctionsMap, functions_map: FunctionsMap,
properties: Properties, properties: Properties,
@ -934,8 +976,8 @@ pub(crate) struct ChatTemplateInputs<'a> {
bos_token: Option<&'a str>, bos_token: Option<&'a str>,
eos_token: Option<&'a str>, eos_token: Option<&'a str>,
add_generation_prompt: bool, add_generation_prompt: bool,
tools: Option<&'a str>, tools: Option<Vec<Tool>>,
tools_prompt: Option<&'a str>, guideline: Option<&'a str>,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)] #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]
@ -981,8 +1023,10 @@ impl MessageContent {
pub fn push(&mut self, chunk: MessageChunk) { pub fn push(&mut self, chunk: MessageChunk) {
match self { match self {
MessageContent::SingleText(text) => { MessageContent::SingleText(text) => {
*self = *self = MessageContent::MultipleChunks(vec![
MessageContent::MultipleChunks(vec![MessageChunk::Text { text: text.clone() }]); MessageChunk::Text { text: text.clone() },
chunk,
]);
} }
MessageContent::MultipleChunks(chunks) => { MessageContent::MultipleChunks(chunks) => {
chunks.push(chunk); chunks.push(chunk);
@ -1038,6 +1082,16 @@ pub(crate) struct GenerateRequest {
pub inputs: String, pub inputs: String,
#[serde(default = "default_parameters")] #[serde(default = "default_parameters")]
pub parameters: GenerateParameters, pub parameters: GenerateParameters,
/// This is used internally because some requests
/// already contain the templated input, and therefore
/// we shouldn't add the special tokens.
#[serde(default = "default_true", skip)]
pub add_special_tokens: bool,
}
fn default_true() -> bool {
true
} }
#[derive(Clone, Debug, Deserialize, ToSchema)] #[derive(Clone, Debug, Deserialize, ToSchema)]
@ -1055,6 +1109,7 @@ impl From<CompatGenerateRequest> for GenerateRequest {
fn from(req: CompatGenerateRequest) -> Self { fn from(req: CompatGenerateRequest) -> Self {
Self { Self {
inputs: req.inputs, inputs: req.inputs,
add_special_tokens: true,
parameters: req.parameters, parameters: req.parameters,
} }
} }
@ -1117,6 +1172,15 @@ impl std::fmt::Display for FinishReason {
} }
} }
impl FinishReason {
pub fn format(&self, use_stop: bool) -> String {
match self {
FinishReason::EndOfSequenceToken if use_stop => "stop".to_string(),
_ => self.to_string(),
}
}
}
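For clarity, a brief sketch of this mapping: OpenAI-compatible endpoints pass use_stop = true so clients see the expected "stop" value on end-of-sequence, while use_stop = false keeps the enum's Display output unchanged.

    // Sketch of the OpenAI-compatible finish_reason mapping.
    fn finish_reason_mapping_sketch() {
        assert_eq!(FinishReason::EndOfSequenceToken.format(true), "stop");
        // with use_stop = false the Display value is kept unchanged
        assert_eq!(
            FinishReason::EndOfSequenceToken.format(false),
            FinishReason::EndOfSequenceToken.to_string()
        );
    }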
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
pub(crate) struct BestOfSequence { pub(crate) struct BestOfSequence {
#[schema(example = "test")] #[schema(example = "test")]
@ -1157,6 +1221,12 @@ pub(crate) struct GenerateResponse {
pub details: Option<Details>, pub details: Option<Details>,
} }
#[derive(Serialize, ToSchema)]
pub(crate) struct ChatTokenizeResponse {
pub(crate) tokenize_response: TokenizeResponse,
pub(crate) templated_text: String,
}
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
#[serde(transparent)] #[serde(transparent)]
pub(crate) struct TokenizeResponse(Vec<SimpleToken>); pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
@ -1169,6 +1239,8 @@ pub(crate) struct StreamDetails {
pub generated_tokens: u32, pub generated_tokens: u32,
#[schema(nullable = true, example = 42)] #[schema(nullable = true, example = 42)]
pub seed: Option<u64>, pub seed: Option<u64>,
#[schema(example = 1)]
pub input_length: u32,
} }
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
@ -1189,6 +1261,34 @@ pub(crate) struct ErrorResponse {
pub error_type: String, pub error_type: String,
} }
#[derive(Serialize, Deserialize, ToSchema)]
pub(crate) struct ModelInfo {
#[schema(example = "gpt2")]
pub id: String,
#[schema(example = "model")]
pub object: String,
#[schema(example = 1686935002)]
pub created: u64,
#[schema(example = "openai")]
pub owned_by: String,
}
#[derive(Serialize, Deserialize, ToSchema)]
pub(crate) struct ModelsInfo {
#[schema(example = "list")]
pub object: String,
pub data: Vec<ModelInfo>,
}
impl Default for ModelsInfo {
fn default() -> Self {
ModelsInfo {
object: "list".to_string(),
data: Vec::new(),
}
}
}
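A small sketch of how the new /v1/models schema serializes; the field values are illustrative, and serde_json is already a router dependency.

    fn models_info_sketch() -> serde_json::Value {
        let models = ModelsInfo {
            data: vec![ModelInfo {
                id: "gpt2".to_string(),
                object: "model".to_string(),
                created: 0,
                owned_by: "openai".to_string(),
            }],
            ..Default::default()
        };
        // {"object":"list","data":[{"id":"gpt2","object":"model","created":0,"owned_by":"openai"}]}
        serde_json::to_value(&models).unwrap()
    }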
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -1296,6 +1396,35 @@ mod tests {
); );
} }
#[test]
fn test_message_content_append() {
let mut content = MessageContent::SingleText("Initial text".to_string());
let chunk = MessageChunk::Text {
text: "Additional text".to_string(),
};
content.push(chunk);
match content {
MessageContent::MultipleChunks(chunks) => {
assert_eq!(chunks.len(), 2);
assert_eq!(
chunks[0],
MessageChunk::Text {
text: "Initial text".to_string()
}
);
assert_eq!(
chunks[1],
MessageChunk::Text {
text: "Additional text".to_string()
}
);
}
_ => panic!("Expected MultipleChunks, but got a different variant"),
}
}
#[test] #[test]
fn test_chat_request() { fn test_chat_request() {
let json = json!({ let json = json!({

View File

@ -8,6 +8,7 @@ use crate::kserve::{
kserve_model_metadata, kserve_model_metadata_ready, kserve_model_metadata, kserve_model_metadata_ready,
}; };
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::{default_tool_prompt, ChatTokenizeResponse};
use crate::{ use crate::{
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@ -23,6 +24,7 @@ use crate::{
VertexResponse, VertexResponse,
}; };
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
use crate::{ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream; use async_stream::__private::AsyncStream;
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
@ -115,6 +117,133 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
Json(info.0) Json(info.0)
} }
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/v1/models",
responses(
(status = 200, description = "Served model info", body = ModelInfo),
(status = 404, description = "Model not found", body = ErrorResponse),
)
)]
#[instrument(skip(info))]
/// Get model info
async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
Json(ModelsInfo {
data: vec![ModelInfo {
id: info.0.model_id.clone(),
object: "model".to_string(),
created: 0, // TODO: determine how to get this
owned_by: info.0.model_id.clone(),
}],
..Default::default()
})
}
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/chat_tokenize",
request_body = ChatRequest,
responses((status = 200, description = "Templated and tokenized ChatRequest", body = ChatTokenizeResponse))
)]
async fn get_chat_tokenize(
Extension(infer): Extension<Infer>,
Json(req): Json<ChatRequest>,
) -> Result<(HeaderMap, Json<ChatTokenizeResponse>), (StatusCode, Json<ErrorResponse>)> {
metrics::counter!("tgi_request_count").increment(1);
let ChatRequest {
model,
max_tokens,
messages,
seed,
stop,
stream,
tools,
tool_choice,
tool_prompt,
temperature,
response_format,
guideline,
..
} = req;
let tool_prompt = tool_prompt.unwrap_or_default();
let (inputs, _grammar, _using_tools) = prepare_chat_input(
&infer,
response_format,
tools,
tool_choice,
&tool_prompt,
guideline,
messages,
)?;
let generate_request = GenerateRequest {
inputs,
add_special_tokens: false,
parameters: GenerateParameters {
best_of: None,
temperature,
repetition_penalty: None,
frequency_penalty: None,
top_k: None,
top_p: None,
typical_p: None,
do_sample: true,
max_new_tokens: max_tokens,
return_full_text: None,
stop: stop.unwrap_or_default(),
truncate: None,
watermark: false,
details: false,
decoder_input_details: !stream,
seed,
top_n_tokens: None,
grammar: _grammar,
adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
},
};
let input = generate_request.inputs.clone();
let encoding = infer.tokenize(generate_request).await?;
if let Some(encoding) = encoding {
let tokens: Vec<SimpleToken> = encoding
.get_ids()
.iter()
.zip(encoding.get_offsets())
.map(|(&id, &(start, stop))| {
let text = input
.chars()
.skip(start)
.take(stop - start)
.collect::<String>();
SimpleToken {
id,
text,
start,
stop,
}
})
.collect();
let resp = ChatTokenizeResponse {
tokenize_response: TokenizeResponse(tokens),
templated_text: input,
};
Ok((HeaderMap::new(), Json(resp)))
} else {
Err((
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: "No fast tokenizer or tokenizer.json for this model".to_string(),
error_type: "no fast tokenizer".to_string(),
}),
))
}
}
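As a usage illustration only (not part of the diff), a client could exercise the new route as below; reqwest with its json feature is an assumed client-side dependency, and the host, port, and payload are placeholders.

    // Hypothetical client for POST /chat_tokenize; returns the templated prompt and its tokens.
    async fn chat_tokenize_example() -> Result<(), reqwest::Error> {
        let body = serde_json::json!({
            "model": "tgi",
            "messages": [{ "role": "user", "content": "What is Deep Learning?" }]
        });
        let resp: serde_json::Value = reqwest::Client::new()
            .post("http://localhost:3000/chat_tokenize")
            .json(&body)
            .send()
            .await?
            .json()
            .await?;
        println!("templated_text: {}", resp["templated_text"]);
        Ok(())
    }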
#[utoipa::path( #[utoipa::path(
get, get,
tag = "Text Generation Inference", tag = "Text Generation Inference",
@ -429,7 +558,7 @@ async fn generate_stream_internal(
} else { } else {
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
// Keep permit as long as generate_stream lives // Keep permit as long as generate_stream lives
Ok((_permit, _input_length, response_stream)) => { Ok((_permit, input_length, response_stream)) => {
let mut index = 0; let mut index = 0;
let mut response_stream = Box::pin(response_stream); let mut response_stream = Box::pin(response_stream);
// Server-Sent Event stream // Server-Sent Event stream
@ -472,6 +601,7 @@ async fn generate_stream_internal(
finish_reason: generated_text.finish_reason, finish_reason: generated_text.finish_reason,
generated_tokens: generated_text.generated_tokens, generated_tokens: generated_text.generated_tokens,
seed: generated_text.seed, seed: generated_text.seed,
input_length,
}), }),
false => None, false => None,
}; };
@ -649,6 +779,7 @@ async fn completions(
.iter() .iter()
.map(|prompt| GenerateRequest { .map(|prompt| GenerateRequest {
inputs: prompt.to_string(), inputs: prompt.to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: None, best_of: None,
temperature, temperature,
@ -697,21 +828,46 @@ async fn completions(
.unwrap_or_else(|_| std::time::Duration::from_secs(0)) .unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs(); .as_secs();
event let message = match stream_token.details {
.json_data(Completion::Chunk(Chunk { Some(details) => {
id: "".to_string(), let completion_tokens = details.generated_tokens;
created: current_time, let prompt_tokens = details.input_length;
let total_tokens = prompt_tokens + completion_tokens;
Completion::Final(CompletionFinal {
id: String::new(),
created: current_time,
model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(),
choices: vec![CompletionComplete {
finish_reason: details.finish_reason.to_string(),
index: index as u32,
logprobs: None,
text: stream_token.token.text,
}],
usage: Usage {
prompt_tokens,
completion_tokens,
total_tokens,
},
})
}
None => Completion::Chunk(Chunk {
id: String::new(),
created: current_time,
choices: vec![CompletionComplete { choices: vec![CompletionComplete {
finish_reason: "".to_string(), finish_reason: String::new(),
index: index as u32, index: index as u32,
logprobs: None, logprobs: None,
text: stream_token.token.text, text: stream_token.token.text,
}], }],
model: model_id.clone(), model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(), system_fingerprint: system_fingerprint.clone(),
})) }),
};
event
.json_data(message)
.unwrap_or_else(|_e| Event::default()) .unwrap_or_else(|_e| Event::default())
}; };
@ -919,7 +1075,7 @@ async fn completions(
total_tokens += details.prefill.len() as u32 + details.generated_tokens; total_tokens += details.prefill.len() as u32 + details.generated_tokens;
Ok(CompletionComplete { Ok(CompletionComplete {
finish_reason: details.finish_reason.to_string(), finish_reason: details.finish_reason.format(true),
index: index as u32, index: index as u32,
logprobs: None, logprobs: None,
text: generation.generated_text, text: generation.generated_text,
@ -1021,80 +1177,36 @@ async fn chat_completions(
tool_prompt, tool_prompt,
temperature, temperature,
response_format, response_format,
guideline,
.. ..
} = req; } = req;
let repetition_penalty = presence_penalty.map(|x| x + 2.0); let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let max_new_tokens = max_tokens.or(Some(100)); let max_new_tokens = max_tokens.or(Some(100));
let logprobs = logprobs.unwrap_or(false); let logprobs = logprobs.unwrap_or(false);
let tool_prompt = tool_prompt.unwrap_or_default(); let tool_prompt = tool_prompt
.filter(|s| !s.is_empty())
.unwrap_or_else(default_tool_prompt);
let stop = stop.unwrap_or_default(); let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0 // enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature { let (do_sample, temperature) = match temperature {
Some(temperature) if temperature == 0.0 => (false, None), Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other), other => (true, other),
}; };
let (inputs, grammar, using_tools) = prepare_chat_input(
// response_format and tools are mutually exclusive &infer,
if response_format.is_some() && tools.as_ref().is_some() { response_format,
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); tools,
return Err(( tool_choice,
StatusCode::UNPROCESSABLE_ENTITY, &tool_prompt,
Json(ErrorResponse { guideline,
error: "Grammar and tools are mutually exclusive".to_string(), messages,
error_type: "grammar and tools".to_string(), )?;
}),
));
}
// extract tool grammar if present
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
Ok(grammar) => grammar,
Err(err) => {
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: err.to_string(),
error_type: err.error_type().to_string(),
}),
));
}
};
// determine the appropriate arguments for apply_chat_template
let tools_grammar_prompt = tool_grammar
.as_ref()
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
let (tools_grammar_prompt, grammar) = match response_format {
Some(response_format) => (None, Some(response_format)),
None => (
tools_grammar_prompt.clone(),
tools_grammar_prompt.map(|(grammar, _)| grammar.clone()),
),
};
// apply chat template to flatten the request into a single input
let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
Ok(inputs) => inputs,
Err(err) => {
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: err.to_string(),
error_type: err.error_type().to_string(),
}),
));
}
};
// build the request passing some parameters // build the request passing some parameters
let generate_request = GenerateRequest { let generate_request = GenerateRequest {
inputs: inputs.to_string(), inputs: inputs.to_string(),
add_special_tokens: false,
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: None, best_of: None,
temperature, temperature,
@ -1138,7 +1250,7 @@ async fn chat_completions(
}); });
// replace the content with the tool calls if grammar is present // replace the content with the tool calls if grammar is present
let (content, tool_calls) = if tool_grammar.is_some() { let (content, tool_calls) = if using_tools {
(None, Some(vec![stream_token.token.text])) (None, Some(vec![stream_token.token.text]))
} else { } else {
let content = if !stream_token.token.special { let content = if !stream_token.token.special {
@ -1159,7 +1271,7 @@ async fn chat_completions(
tool_calls, tool_calls,
current_time, current_time,
logprobs, logprobs,
stream_token.details.map(|d| d.finish_reason.to_string()), stream_token.details.map(|d| d.finish_reason.format(true)),
), ),
)) ))
.unwrap_or_else(|e| { .unwrap_or_else(|e| {
@ -1192,10 +1304,14 @@ async fn chat_completions(
.unwrap_or_else(|_| std::time::Duration::from_secs(0)) .unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs(); .as_secs();
let (tool_calls, output) = if tool_grammar.is_some() { let (tool_calls, output) = if using_tools {
let gen_text_value: Value = serde_json::from_str(&generation.generated_text) let gen_text_value: Value =
.map_err(|e| InferError::ToolError(e.to_string()))?; serde_json::from_str(&generation.generated_text).map_err(|e| {
InferError::ToolError(format!(
"Failed to parse generated text: {} {:?}",
e, generation.generated_text
))
})?;
let function = gen_text_value.get("function").ok_or(InferError::ToolError( let function = gen_text_value.get("function").ok_or(InferError::ToolError(
"No function found in generated text".to_string(), "No function found in generated text".to_string(),
))?; ))?;
@ -1297,6 +1413,7 @@ async fn vertex_compatibility(
.map(|instance| { .map(|instance| {
let generate_request = GenerateRequest { let generate_request = GenerateRequest {
inputs: instance.inputs.clone(), inputs: instance.inputs.clone(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
do_sample: true, do_sample: true,
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens), max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
@ -1360,8 +1477,11 @@ async fn tokenize(
.iter() .iter()
.zip(encoding.get_offsets()) .zip(encoding.get_offsets())
.map(|(&id, &(start, stop))| { .map(|(&id, &(start, stop))| {
let text: String = let text = input
String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string(); .chars()
.skip(start)
.take(stop - start)
.collect::<String>();
SimpleToken { SimpleToken {
id, id,
text, text,
@ -1409,6 +1529,7 @@ chat_completions,
completions, completions,
tokenize, tokenize,
metrics, metrics,
openai_get_model_info,
), ),
components( components(
schemas( schemas(
@ -1461,6 +1582,7 @@ ToolCall,
Function, Function,
FunctionDefinition, FunctionDefinition,
ToolChoice, ToolChoice,
ModelInfo,
) )
), ),
tags( tags(
@ -1913,6 +2035,120 @@ async fn start(
.install_recorder() .install_recorder()
.expect("failed to install metrics recorder"); .expect("failed to install metrics recorder");
// Metrics descriptions
metrics::describe_counter!("tgi_request_success", "Number of successful requests");
metrics::describe_histogram!(
"tgi_request_duration",
metrics::Unit::Seconds,
"Request duration"
);
metrics::describe_histogram!(
"tgi_request_validation_duration",
metrics::Unit::Seconds,
"Request validation duration"
);
metrics::describe_histogram!(
"tgi_request_queue_duration",
metrics::Unit::Seconds,
"Request queue duration"
);
metrics::describe_histogram!(
"tgi_request_inference_duration",
metrics::Unit::Seconds,
"Request inference duration"
);
metrics::describe_histogram!(
"tgi_request_mean_time_per_token_duration",
metrics::Unit::Seconds,
"Mean time per token per request"
);
metrics::describe_histogram!(
"tgi_request_generated_tokens",
metrics::Unit::Count,
"Generated tokens per request"
);
metrics::describe_counter!(
"tgi_batch_inference_count",
metrics::Unit::Count,
"Inference calls per method (prefill or decode)"
);
metrics::describe_counter!(
"tgi_request_count",
metrics::Unit::Count,
"Total number of requests"
);
metrics::describe_counter!(
"tgi_batch_inference_success",
metrics::Unit::Count,
"Number of successful inference calls per method (prefill or decode)"
);
metrics::describe_gauge!(
"tgi_batch_current_size",
metrics::Unit::Count,
"Current batch size"
);
metrics::describe_gauge!("tgi_queue_size", metrics::Unit::Count, "Current queue size");
metrics::describe_gauge!(
"tgi_batch_current_max_tokens",
metrics::Unit::Count,
"Maximum tokens for the current batch"
);
metrics::describe_histogram!(
"tgi_request_max_new_tokens",
metrics::Unit::Count,
"Maximum new tokens per request"
);
metrics::describe_histogram!(
"tgi_batch_inference_duration",
metrics::Unit::Seconds,
"Batch inference duration"
);
metrics::describe_histogram!(
"tgi_batch_forward_duration",
metrics::Unit::Seconds,
"Batch forward duration per method (prefill or decode)"
);
metrics::describe_histogram!(
"tgi_request_skipped_tokens",
metrics::Unit::Count,
"Speculated tokens per request"
);
metrics::describe_histogram!(
"tgi_batch_filter_duration",
metrics::Unit::Seconds,
"Time spent filtering batches and sending generated tokens per method (prefill or decode)"
);
metrics::describe_histogram!(
"tgi_request_queue_duration",
metrics::Unit::Seconds,
"Time spent in the queue per request"
);
metrics::describe_histogram!(
"tgi_request_validation_duration",
metrics::Unit::Seconds,
"Time spent validating the request"
);
metrics::describe_histogram!(
"tgi_request_duration",
metrics::Unit::Seconds,
"Total time spent processing the request"
);
metrics::describe_histogram!(
"tgi_batch_decode_duration",
metrics::Unit::Seconds,
"Time spent decoding a batch per method (prefill or decode)"
);
metrics::describe_histogram!(
"tgi_request_input_length",
metrics::Unit::Count,
"Input token length per request"
);
metrics::describe_histogram!(
"tgi_batch_next_size",
metrics::Unit::Count,
"Batch size of the next batch"
);
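These describe_* calls only attach units and help text; the actual recording happens at call sites elsewhere in the router, roughly as in the sketch below (the elapsed timer argument is illustrative).

    // Sketch: recording against metrics described above.
    fn record_request_metrics_sketch(start: std::time::Instant) {
        metrics::counter!("tgi_request_success").increment(1);
        metrics::histogram!("tgi_request_duration").record(start.elapsed().as_secs_f64());
    }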
// CORS layer // CORS layer
let allow_origin = allow_origin.unwrap_or(AllowOrigin::any()); let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
let cors_layer = CorsLayer::new() let cors_layer = CorsLayer::new()
@ -2036,10 +2272,12 @@ async fn start(
} }
let info_routes = Router::new() let info_routes = Router::new()
.route("/", get(health)) .route("/", get(health))
.route("/chat_tokenize", post(get_chat_tokenize))
.route("/info", get(get_model_info)) .route("/info", get(get_model_info))
.route("/health", get(health)) .route("/health", get(health))
.route("/ping", get(health)) .route("/ping", get(health))
.route("/metrics", get(metrics)); .route("/metrics", get(metrics))
.route("/v1/models", get(openai_get_model_info));
// Conditional AWS Sagemaker route // Conditional AWS Sagemaker route
let aws_sagemaker_route = if messages_api_enabled { let aws_sagemaker_route = if messages_api_enabled {
@ -2232,6 +2470,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR, InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY, InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
}; };
@ -2332,3 +2571,157 @@ fn create_post_processor(
Ok(post_processor) Ok(post_processor)
} }
type PreparedInput = (String, Option<GrammarType>, bool);
fn prepare_chat_input(
infer: &Infer,
response_format: Option<GrammarType>,
tools: Option<Vec<Tool>>,
tool_choice: ToolChoice,
tool_prompt: &str,
guideline: Option<String>,
messages: Vec<Message>,
) -> Result<PreparedInput, InferError> {
if response_format.is_some() && tools.is_some() {
return Err(InferError::ToolError(
"Grammar and tools are mutually exclusive".into(),
));
}
// when response_format is set, tools are not included when applying the chat template to generate inputs
if let Some(format) = response_format {
let inputs = infer.apply_chat_template(guideline, messages, None)?;
return Ok((inputs, Some(format), false));
}
// when no response_format is set and tools are included, apply the chat template with the tools
// to generate inputs
if let Some(tools) = tools {
let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?;
let grammar = tool_schema
.as_ref()
.map(|t| GrammarType::Json(serde_json::json!(t)));
let inputs: String = infer.apply_chat_template(
guideline,
messages,
Some((updated_tools, tool_prompt.into())),
)?;
return Ok((inputs, grammar, tool_schema.is_some()));
}
// if no response_format or tools are set simply apply the chat template to generate inputs
let inputs = infer.apply_chat_template(guideline, messages, None)?;
Ok((inputs, None, false))
}
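For illustration, the mutual-exclusivity guard means a request carrying both a response_format and tools surfaces as an InferError::ToolError before any templating happens. The fragment below is a sketch only; infer, tools, and messages are assumed to be set up as in the test that follows.

    // Sketch of the error path (not a real test): grammar and tools cannot be combined.
    let result = prepare_chat_input(
        &infer,
        Some(GrammarType::Json(serde_json::json!({"type": "object"}))),
        Some(tools),
        ToolChoice(None),
        "tool prompt",
        None,
        messages,
    );
    assert!(matches!(result, Err(InferError::ToolError(_))));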
#[cfg(test)]
mod tests {
use super::*;
use crate::ChatTemplateVersions;
use crate::HubTokenizerConfig;
use crate::TokenizerConfigToken;
use crate::Tool;
use serde_json::json;
#[test]
fn test_prepare_chat_input() {
// Mock Backend to avoid network requests
struct MockBackend;
impl Backend for MockBackend {
fn schedule(
&self,
_request: crate::validation::ValidGenerateRequest,
) -> Result<
tokio_stream::wrappers::UnboundedReceiverStream<
Result<InferStreamResponse, InferError>,
>,
InferError,
> {
unimplemented!("Never called in this test");
}
fn health<'a, 'async_trait>(
&'a self,
_current_health: bool,
) -> core::pin::Pin<
Box<dyn core::future::Future<Output = bool> + core::marker::Send + 'async_trait>,
>
where
'a: 'async_trait,
Self: 'async_trait,
{
unimplemented!("Never called in this test");
}
}
let backend = MockBackend {};
let mut tokenizer_config = HubTokenizerConfig::default();
// mock tokenizer config values
tokenizer_config.bos_token = Some(TokenizerConfigToken::String("<s>".to_string()));
tokenizer_config.eos_token = Some(TokenizerConfigToken::String("</s>".to_string()));
tokenizer_config.chat_template = Some(
ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string())
);
let infer = Infer::new(
backend,
Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false),
1,
tokenizer_config,
HubProcessorConfig::default(),
);
let response_format = None;
let tools = Some(vec![Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
name: "get_current_weather".to_string(),
description: Some("Get the current weather".to_string()),
arguments: json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location."
}
},
"required": ["location", "format"]
}),
},
}]);
let tool_prompt = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.";
let guideline = None;
let messages = vec![Message {
name: None,
role: "user".to_string(),
content: MessageContent::SingleText(
"What is the weather like in New York?".to_string(),
),
}];
let result = prepare_chat_input(
&infer,
response_format,
tools,
ToolChoice(None),
tool_prompt,
guideline,
messages,
);
assert!(result.is_ok());
let (inputs, _grammar, using_tools) = result.unwrap();
assert_eq!(using_tools, true);
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
}
}

View File

@ -11,6 +11,7 @@ use rand::{thread_rng, Rng};
use serde_json::Value; use serde_json::Value;
use std::io::Cursor; use std::io::Cursor;
use std::iter; use std::iter;
use std::sync::Arc;
use thiserror::Error; use thiserror::Error;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc; use tokio::sync::mpsc;
@ -94,6 +95,7 @@ impl Validation {
pub async fn tokenize( pub async fn tokenize(
&self, &self,
inputs: String, inputs: String,
add_special_tokens: bool,
truncate: Option<usize>, truncate: Option<usize>,
) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> { ) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
// If we have a fast tokenizer // If we have a fast tokenizer
@ -103,7 +105,11 @@ impl Validation {
// Send request to the background validation task // Send request to the background validation task
// Unwrap is safe here // Unwrap is safe here
sender sender
.send(((inputs, truncate), response_sender, Span::current())) .send((
(inputs, add_special_tokens, truncate),
response_sender,
Span::current(),
))
.unwrap(); .unwrap();
// Await on response channel // Await on response channel
@ -115,15 +121,20 @@ impl Validation {
} }
} }
#[allow(clippy::type_complexity)]
#[instrument(skip(self, inputs))] #[instrument(skip(self, inputs))]
async fn validate_input( async fn validate_input(
&self, &self,
inputs: String, inputs: String,
add_special_tokens: bool,
truncate: Option<usize>, truncate: Option<usize>,
max_new_tokens: Option<u32>, max_new_tokens: Option<u32>,
) -> Result<(Vec<Chunk>, usize, u32), ValidationError> { ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
// If we have a fast tokenizer // If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { if let Some((encoding, inputs)) = self
.tokenize(inputs.clone(), add_special_tokens, truncate)
.await?
{
// Create response channel // Create response channel
let input_length = if let Some(truncate) = truncate { let input_length = if let Some(truncate) = truncate {
std::cmp::min(encoding.len(), truncate) std::cmp::min(encoding.len(), truncate)
@ -156,8 +167,11 @@ impl Validation {
)); ));
} }
let ids = encoding.get_ids();
let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
metrics::histogram!("tgi_request_input_length").record(input_length as f64); metrics::histogram!("tgi_request_input_length").record(input_length as f64);
Ok((inputs, input_length, max_new_tokens)) Ok((inputs, Some(input_ids), input_length, max_new_tokens))
} }
// Return inputs without validation // Return inputs without validation
else { else {
@ -180,7 +194,12 @@ impl Validation {
input_length = input_length.saturating_sub(max_new_tokens as usize); input_length = input_length.saturating_sub(max_new_tokens as usize);
} }
Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens)) Ok((
vec![Chunk::Text(inputs)],
None,
input_length,
max_new_tokens,
))
} }
} }
@ -314,8 +333,13 @@ impl Validation {
.unwrap_or(Ok(None))?; .unwrap_or(Ok(None))?;
// Validate inputs // Validate inputs
let (inputs, input_length, max_new_tokens) = self let (inputs, input_ids, input_length, max_new_tokens) = self
.validate_input(request.inputs, truncate, max_new_tokens) .validate_input(
request.inputs,
request.add_special_tokens,
truncate,
max_new_tokens,
)
.await?; .await?;
// TODO: we should build the FSM here and pass the compiled FSM instead of the grammar // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
@ -391,6 +415,8 @@ impl Validation {
Ok(ValidGenerateRequest { Ok(ValidGenerateRequest {
inputs, inputs,
input_ids: input_ids.map(Arc::new),
add_special_tokens: request.add_special_tokens,
decoder_input_details, decoder_input_details,
input_length: input_length as u32, input_length: input_length as u32,
truncate: truncate.unwrap_or(self.max_input_length) as u32, truncate: truncate.unwrap_or(self.max_input_length) as u32,
@ -439,12 +465,15 @@ fn tokenizer_worker(
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
) { ) {
// Loop over requests // Loop over requests
while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() { while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
receiver.blocking_recv()
{
parent_span.in_scope(|| { parent_span.in_scope(|| {
response_tx response_tx
.send(prepare_input( .send(prepare_input(
inputs, inputs,
truncate, truncate,
add_special_tokens,
&tokenizer, &tokenizer,
config.as_ref(), config.as_ref(),
preprocessor_config.as_ref(), preprocessor_config.as_ref(),
@ -581,6 +610,7 @@ fn image_tokens_fixup(config: &Config, text: String) -> String {
fn prepare_input( fn prepare_input(
inputs: String, inputs: String,
_truncate: Option<usize>, _truncate: Option<usize>,
add_special_tokens: bool,
tokenizer: &Tokenizer, tokenizer: &Tokenizer,
config: Option<&Config>, config: Option<&Config>,
preprocessor_config: Option<&HubPreprocessorConfig>, preprocessor_config: Option<&HubPreprocessorConfig>,
@ -618,14 +648,14 @@ fn prepare_input(
// Get the number of tokens in the input // Get the number of tokens in the input
let encoding = tokenizer let encoding = tokenizer
.encode(tokenizer_query, true) .encode(tokenizer_query, add_special_tokens)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?; .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
Ok((encoding, input_chunks)) Ok((encoding, input_chunks))
} }
type TokenizerRequest = ( type TokenizerRequest = (
(String, Option<usize>), (String, bool, Option<usize>),
oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>, oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
Span, Span,
); );
@ -707,8 +737,10 @@ pub struct ValidStoppingParameters {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ValidGenerateRequest { pub struct ValidGenerateRequest {
pub inputs: Vec<Chunk>, pub inputs: Vec<Chunk>,
pub input_ids: Option<Arc<Vec<u32>>>,
pub input_length: u32, pub input_length: u32,
pub truncate: u32, pub truncate: u32,
pub add_special_tokens: bool,
pub decoder_input_details: bool, pub decoder_input_details: bool,
pub parameters: ValidParameters, pub parameters: ValidParameters,
pub stopping_parameters: ValidStoppingParameters, pub stopping_parameters: ValidStoppingParameters,
@ -815,11 +847,11 @@ mod tests {
let max_new_tokens = 10; let max_new_tokens = 10;
match validation match validation
.validate_input("Hello".to_string(), None, Some(max_new_tokens)) .validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
.await .await
{ {
// Err(ValidationError::MaxNewTokens(1, 10)) => (), // Err(ValidationError::MaxNewTokens(1, 10)) => (),
Ok((_s, 0, 10)) => (), Ok((_s, _, 0, 10)) => (),
r => panic!("Unexpected not max new tokens: {r:?}"), r => panic!("Unexpected not max new tokens: {r:?}"),
} }
} }
@ -850,7 +882,7 @@ mod tests {
let max_new_tokens = 10; let max_new_tokens = 10;
match validation match validation
.validate_input("Hello".to_string(), None, Some(max_new_tokens)) .validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
.await .await
{ {
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (), Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
@ -884,6 +916,7 @@ mod tests {
match validation match validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: Some(2), best_of: Some(2),
do_sample: false, do_sample: false,
@ -923,6 +956,7 @@ mod tests {
match validation match validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_p: Some(1.0), top_p: Some(1.0),
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -938,6 +972,7 @@ mod tests {
match validation match validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_p: Some(0.99), top_p: Some(0.99),
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -953,6 +988,7 @@ mod tests {
let valid_request = validation let valid_request = validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_p: None, top_p: None,
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -991,6 +1027,7 @@ mod tests {
match validation match validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_n_tokens: Some(5), top_n_tokens: Some(5),
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -1006,6 +1043,7 @@ mod tests {
validation validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_n_tokens: Some(4), top_n_tokens: Some(4),
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -1018,6 +1056,7 @@ mod tests {
validation validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_n_tokens: Some(0), top_n_tokens: Some(0),
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -1030,6 +1069,7 @@ mod tests {
let valid_request = validation let valid_request = validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: "Hello".to_string(), inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters { parameters: GenerateParameters {
top_n_tokens: None, top_n_tokens: None,
max_new_tokens: Some(5), max_new_tokens: Some(5),
@ -1078,6 +1118,7 @@ mod tests {
let chunks = match validation let chunks = match validation
.tokenize( .tokenize(
format!("test![](data:image/gif;base64,{})", PIXEL_GIF), format!("test![](data:image/gif;base64,{})", PIXEL_GIF),
true,
None, None,
) )
.await .await
@ -1137,6 +1178,7 @@ mod tests {
"test![](data:image/gif;base64,{})![](data:image/gif;base64,{})", "test![](data:image/gif;base64,{})![](data:image/gif;base64,{})",
PIXEL_GIF, PIXEL_GIF PIXEL_GIF, PIXEL_GIF
), ),
true,
None, None,
) )
.await .await

View File

@ -1,5 +1,5 @@
[toolchain] [toolchain]
# Released on: June 13, 2024 # Released on: June 13, 2024
# https://releases.rs/docs/1.79.0/ # https://releases.rs/docs/1.79.0/
channel = "1.79.0" channel = "1.80.0"
components = ["rustfmt", "clippy"] components = ["rustfmt", "clippy"]

View File

@ -6,6 +6,8 @@ include Makefile-eetq
include Makefile-selective-scan include Makefile-selective-scan
include Makefile-lorax-punica include Makefile-lorax-punica
include Makefile-fbgemm include Makefile-fbgemm
include Makefile-exllamav2
include Makefile-flashinfer
unit-tests: unit-tests:
pytest -s -vv -m "not private" tests pytest -s -vv -m "not private" tests

server/Makefile-exllamav2 (new file, 12 lines)
View File

@ -0,0 +1,12 @@
exllamav2_commit := v0.1.8
build-exllamav2:
git clone https://github.com/turboderp/exllamav2.git exllamav2 && \
cd exllamav2 && git fetch && git checkout $(exllamav2_commit) && \
git submodule update --init --recursive && \
pip install -r requirements.txt && \
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py build
install-exllamav2: build-exllamav2
cd exllamav2/ && \
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py install

View File

@ -1,4 +1,4 @@
fbgemm_commit := ddac8dd9fc0bee70a3f456df68b8aac38576c856 fbgemm_commit := v0.8.0
build-fbgemm: build-fbgemm:
git clone https://github.com/pytorch/FBGEMM.git fbgemm && \ git clone https://github.com/pytorch/FBGEMM.git fbgemm && \

View File

@ -0,0 +1,2 @@
install-flashinfer:
pip install flashinfer==0.1.5 -i https://flashinfer.ai/whl/cu124/torch2.4

View File

@ -1,10 +1,7 @@
from setuptools import setup from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch
extra_compile_args = ["-std=c++17"] extra_compile_args = ["-std=c++17"]
if not torch.version.hip:
extra_compile_args.append("-arch=compute_80")
setup( setup(
name="custom_kernels", name="custom_kernels",

server/poetry.lock (generated, 69 lines changed)
View File

@ -1070,6 +1070,30 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
[package.extras] [package.extras]
dev = ["Sphinx (>=4.1.1)", "black (>=19.10b0)", "colorama (>=0.3.4)", "docutils (==0.16)", "flake8 (>=3.7.7)", "isort (>=5.1.1)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "tox (>=3.9.0)"] dev = ["Sphinx (>=4.1.1)", "black (>=19.10b0)", "colorama (>=0.3.4)", "docutils (==0.16)", "flake8 (>=3.7.7)", "isort (>=5.1.1)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "tox (>=3.9.0)"]
[[package]]
name = "markdown-it-py"
version = "3.0.0"
description = "Python port of markdown-it. Markdown parsing, done right!"
optional = false
python-versions = ">=3.8"
files = [
{file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
{file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
]
[package.dependencies]
mdurl = ">=0.1,<1.0"
[package.extras]
benchmarking = ["psutil", "pytest", "pytest-benchmark"]
code-style = ["pre-commit (>=3.0,<4.0)"]
compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
linkify = ["linkify-it-py (>=1,<3)"]
plugins = ["mdit-py-plugins"]
profiling = ["gprof2dot"]
rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
[[package]] [[package]]
name = "markupsafe" name = "markupsafe"
version = "2.1.5" version = "2.1.5"
@ -1207,6 +1231,17 @@ torch = "*"
type = "url" type = "url"
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl" url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
[[package]]
name = "mdurl"
version = "0.1.2"
description = "Markdown URL utilities"
optional = false
python-versions = ">=3.7"
files = [
{file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
]
[[package]] [[package]]
name = "mpmath" name = "mpmath"
version = "1.3.0" version = "1.3.0"
@ -2277,6 +2312,20 @@ files = [
[package.dependencies] [package.dependencies]
typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
[[package]]
name = "pygments"
version = "2.18.0"
description = "Pygments is a syntax highlighting package written in Python."
optional = false
python-versions = ">=3.8"
files = [
{file = "pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a"},
{file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"},
]
[package.extras]
windows-terminal = ["colorama (>=0.4.6)"]
[[package]] [[package]]
name = "pytest" name = "pytest"
version = "7.4.4" version = "7.4.4"
@ -2508,6 +2557,24 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"] socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "rich"
version = "13.7.1"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "rich-13.7.1-py3-none-any.whl", hash = "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222"},
{file = "rich-13.7.1.tar.gz", hash = "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"},
]
[package.dependencies]
markdown-it-py = ">=2.2.0"
pygments = ">=2.13.0,<3.0.0"
[package.extras]
jupyter = ["ipywidgets (>=7.5.1,<9)"]
[[package]] [[package]]
name = "rpds-py" name = "rpds-py"
version = "0.19.0" version = "0.19.0"
@ -3584,4 +3651,4 @@ torch = ["torch"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.9,<3.13" python-versions = ">=3.9,<3.13"
content-hash = "a89867b23017d2efa8a7aa14d4764bcbd3b4dea9bfbf06a7a68464cb184ac6a1" content-hash = "0ff7a244a409b616490cb238995bbe28dedf67ccb8855edafa2b71ee2e777dbd"

View File

@ -46,6 +46,7 @@ marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
] ]
rich = "^13.7.1"
[tool.poetry.extras] [tool.poetry.extras]
torch = ["torch"] torch = ["torch"]

View File

@ -16,6 +16,8 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
@ -32,9 +34,11 @@ pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -16,6 +16,8 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
@ -32,9 +34,11 @@ pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -16,6 +16,8 @@ huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-common==1.25.0 ; python_version >= "3.9" and python_version < "3.13"
@ -32,9 +34,11 @@ pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,7 +1,10 @@
import pytest import pytest
import os
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
os.environ["USE_PREFIX_CACHING"] = "1"
os.environ["ATTENTION"] = "flashinfer"
@pytest.fixture @pytest.fixture
def default_pb_parameters(): def default_pb_parameters():

Some files were not shown because too many files have changed in this diff.