fixed merge conflicts

Commit: 058162685f

@ -32,10 +32,6 @@ jobs:
permissions:
|
||||
contents: write
|
||||
packages: write
|
||||
# This is used to complete the identity challenge
|
||||
# with sigstore/fulcio when running outside of PRs.
|
||||
id-token: write
|
||||
security-events: write
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
@ -50,6 +46,7 @@ jobs:
|
|||
export label_extension=""
|
||||
export docker_devices=""
|
||||
export runs_on="aws-g6-12xlarge-plus-priv"
|
||||
export platform=""
|
||||
;;
|
||||
rocm)
|
||||
export dockerfile="Dockerfile_amd"
|
||||
|
@ -58,12 +55,21 @@ jobs:
|
|||
# TODO Re-enable when they pass.
|
||||
# export runs_on="amd-gpu-tgi"
|
||||
export runs_on="ubuntu-latest"
|
||||
export platform=""
|
||||
;;
|
||||
intel)
|
||||
intel-xpu)
|
||||
export dockerfile="Dockerfile_intel"
|
||||
export label_extension="-intel"
|
||||
export label_extension="-intel-xpu"
|
||||
export docker_devices=""
|
||||
export runs_on="ubuntu-latest"
|
||||
export platform="xpu"
|
||||
;;
|
||||
intel-cpu)
|
||||
export dockerfile="Dockerfile_intel"
|
||||
export label_extension="-intel-cpu"
|
||||
export docker_devices=""
|
||||
export runs_on="ubuntu-latest"
|
||||
export platform="cpu"
|
||||
;;
|
||||
esac
|
||||
echo $dockerfile
|
||||
|
@ -71,8 +77,10 @@ jobs:
|
|||
echo $label_extension
|
||||
echo $docker_devices
|
||||
echo $runs_on
|
||||
echo $platform
|
||||
echo "DOCKERFILE=${dockerfile}" >> $GITHUB_ENV
|
||||
echo "LABEL=${label_extension}" >> $GITHUB_ENV
|
||||
echo "PLATFORM=${platform}" >> $GITHUB_ENV
|
||||
echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV
|
||||
echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV
|
||||
echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV
|
||||
|
@ -139,6 +147,7 @@ jobs:
|
|||
build-args: |
|
||||
GIT_SHA=${{ env.GITHUB_SHA }}
|
||||
DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
|
||||
PLATFORM=${{ env.PLATFORM }}
|
||||
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
|
||||
cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
|
||||
|
@ -159,7 +168,7 @@ jobs:
|
|||
group: ${{ needs.build-and-push.outputs.runs_on }}
|
||||
if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
|
||||
env:
|
||||
PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }}
|
||||
PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '--release' }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
|
|
@ -11,7 +11,7 @@ concurrency:
|
|||
|
||||
jobs:
|
||||
build:
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yaml@main
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
|
||||
with:
|
||||
commit_sha: ${{ github.event.pull_request.head.sha }}
|
||||
pr_number: ${{ github.event.number }}
|
||||
|
|
|
@ -37,8 +37,11 @@ jobs:
|
|||
# fail-fast is true by default
|
||||
fail-fast: false
|
||||
matrix:
|
||||
hardware: ["cuda", "rocm", "intel"]
|
||||
hardware: ["cuda", "rocm", "intel-xpu", "intel-cpu"]
|
||||
uses: ./.github/workflows/build.yaml # calls the one above ^
|
||||
permissions:
|
||||
contents: write
|
||||
packages: write
|
||||
with:
|
||||
hardware: ${{ matrix.hardware }}
|
||||
# https://github.com/actions/runner/issues/2206
|
||||
|
|
|
@ -35,7 +35,7 @@ jobs:
|
|||
with:
|
||||
# Released on: 02 May, 2024
|
||||
# https://releases.rs/docs/1.78.0/
|
||||
toolchain: 1.79.0
|
||||
toolchain: 1.80.0
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Install Protoc
|
||||
|
|
|
@ -9,7 +9,7 @@ backends/client/src/v3/pb
|
|||
|
||||
# ROCm auto-generated files
|
||||
*.hip
|
||||
server/exllamav2_kernels/exllamav2_kernels/hip/
|
||||
server/exllamav2
|
||||
server/exllama_kernels/exllama_kernels/hip/
|
||||
server/exllama_kernels/exllama_kernels/hip_func/
|
||||
*_hip.cuh
|
||||
|
@ -18,3 +18,7 @@ server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
|
|||
|
||||
data/
|
||||
load_tests/*.json
|
||||
server/fbgemmm
|
||||
|
||||
.direnv/
|
||||
.venv/
|
||||
|
|
|
@ -5,7 +5,7 @@ repos:
|
|||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
exclude: docs/source/basic_tutorials/launcher.md
|
||||
exclude: docs/source/reference/launcher.md
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 24.2.0
|
||||
hooks:
|
||||
|
|
|
@ -77,3 +77,4 @@ docs/openapi.json:
|
|||
- '#/paths/~1tokenize/post'
|
||||
- '#/paths/~1v1~1chat~1completions/post'
|
||||
- '#/paths/~1v1~1completions/post'
|
||||
- '#/paths/~1v1~1models/get'
|
||||
|
|
File diff suppressed because it is too large.
|
@ -29,6 +29,8 @@ tokenizers = { version = "0.19.1", features = ["http"] }
|
|||
hf-hub = { version = "0.3.1", features = ["tokio"] }
|
||||
metrics = { version = "0.23.0" }
|
||||
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
|
||||
minijinja = { version = "2.2.0", features = ["json"] }
|
||||
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
||||
|
||||
[profile.release]
|
||||
incremental = true
|
||||
|
|
Dockerfile
|
@ -1,5 +1,5 @@
|
|||
# Rust builder
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
|
||||
WORKDIR /usr/src
|
||||
|
||||
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
|
||||
|
@ -40,14 +40,14 @@ RUN cargo build --profile release-opt
|
|||
|
||||
# Python builder
|
||||
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
|
||||
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install
|
||||
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
|
||||
|
||||
# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
|
||||
ARG PYTORCH_VERSION=2.4.0
|
||||
|
||||
ARG PYTHON_VERSION=3.10
|
||||
# Keep in sync with `server/pyproject.toml
|
||||
ARG CUDA_VERSION=12.1
|
||||
ARG CUDA_VERSION=12.4
|
||||
ARG MAMBA_VERSION=24.3.0-0
|
||||
ARG CUDA_CHANNEL=nvidia
|
||||
ARG INSTALL_CHANNEL=pytorch
|
||||
|
@ -88,6 +88,7 @@ RUN case ${TARGETPLATFORM} in \
|
|||
FROM pytorch-install AS kernel-builder
|
||||
|
||||
ARG MAX_JOBS=8
|
||||
ENV TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0+PTX"
|
||||
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
ninja-build cmake \
|
||||
|
@ -118,29 +119,29 @@ FROM kernel-builder AS exllama-kernels-builder
|
|||
WORKDIR /usr/src
|
||||
COPY server/exllama_kernels/ .
|
||||
|
||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
|
||||
RUN python setup.py build
|
||||
|
||||
# Build Transformers exllama kernels
|
||||
FROM kernel-builder AS exllamav2-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/exllamav2_kernels/ .
|
||||
COPY server/Makefile-exllamav2/ Makefile
|
||||
|
||||
# Build specific version of transformers
|
||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
|
||||
RUN make build-exllamav2
|
||||
|
||||
# Build Transformers awq kernels
|
||||
FROM kernel-builder AS awq-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/Makefile-awq Makefile
|
||||
# Build specific version of transformers
|
||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq
|
||||
RUN make build-awq
|
||||
|
||||
# Build eetq kernels
|
||||
FROM kernel-builder AS eetq-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/Makefile-eetq Makefile
|
||||
# Build specific version of transformers
|
||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq
|
||||
RUN make build-eetq
|
||||
|
||||
# Build Lorax Punica kernels
|
||||
FROM kernel-builder AS lorax-punica-builder
|
||||
|
@ -183,6 +184,12 @@ WORKDIR /usr/src
|
|||
COPY server/Makefile-selective-scan Makefile
|
||||
RUN make build-all
|
||||
|
||||
# Build flashinfer
|
||||
FROM kernel-builder AS flashinfer-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/Makefile-flashinfer Makefile
|
||||
RUN make install-flashinfer
|
||||
|
||||
# Text Generation Inference base image
|
||||
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base
|
||||
|
||||
|
@ -191,7 +198,7 @@ ENV PATH=/opt/conda/bin:$PATH \
|
|||
CONDA_PREFIX=/opt/conda
|
||||
|
||||
# Text Generation Inference base env
|
||||
ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||
ENV HF_HOME=/data \
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||
PORT=80
|
||||
|
||||
|
@ -221,11 +228,13 @@ COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /
|
|||
# Copy build artifacts from exllama kernels builder
|
||||
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
# Copy build artifacts from exllamav2 kernels builder
|
||||
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
# Copy build artifacts from awq kernels builder
|
||||
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
# Copy build artifacts from eetq kernels builder
|
||||
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
# Copy build artifacts from lorax punica kernels builder
|
||||
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
# Copy build artifacts from fbgemm builder
|
||||
COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
|
||||
# Copy build artifacts from vllm builder
|
||||
|
@ -233,6 +242,7 @@ COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/c
|
|||
# Copy build artifacts from mamba builder
|
||||
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=flashinfer-builder /opt/conda/lib/python3.10/site-packages/flashinfer/ /opt/conda/lib/python3.10/site-packages/flashinfer/
|
||||
|
||||
# Install flash-attention dependencies
|
||||
RUN pip install einops --no-cache-dir
|
||||
|
@ -248,6 +258,9 @@ RUN cd server && \
|
|||
pip install nvidia-nccl-cu12==2.22.3
|
||||
|
||||
ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
|
||||
# This is needed because exl2 tries to load flash-attn
|
||||
# And fails with our builds.
|
||||
ENV EXLLAMA_NO_FLASH_ATTN=1
|
||||
|
||||
# Deps before the binaries
|
||||
# The binaries change on every build given we burn the SHA into them
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Rust builder
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
|
||||
WORKDIR /usr/src
|
||||
|
||||
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
|
||||
|
@ -199,7 +199,7 @@ RUN python setup.py build
|
|||
FROM base AS base-copy
|
||||
|
||||
# Text Generation Inference base env
|
||||
ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||
ENV HF_HOME=/data \
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||
PORT=80
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
ARG PLATFORM=xpu
|
||||
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
|
||||
WORKDIR /usr/src
|
||||
|
||||
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
|
||||
|
@ -57,7 +57,7 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO
|
|||
RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils
|
||||
|
||||
# Text Generation Inference base env
|
||||
ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||
ENV HF_HOME=/data \
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||
PORT=80
|
||||
|
||||
|
@ -106,7 +106,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
|||
g++ \
|
||||
git \
|
||||
wget \
|
||||
cmake
|
||||
cmake \
|
||||
libnuma-dev
|
||||
|
||||
ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||
|
@ -135,7 +136,7 @@ RUN conda install -c conda-forge gperftools mkl
|
|||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
|
||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
|
||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
|
||||
RUN pip install triton
|
||||
RUN pip install triton numa
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
|
@ -147,16 +148,11 @@ RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update
|
|||
|
||||
RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
|
||||
|
||||
ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so
|
||||
ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
|
||||
ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
|
||||
ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
|
||||
ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
|
||||
ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
|
||||
ENV KMP_BLOCKTIME=1
|
||||
ENV KMP_TPAUSE=0
|
||||
ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist
|
||||
ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist
|
||||
ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist
|
||||
|
||||
# Install server
|
||||
COPY proto proto
|
||||
|
@ -175,5 +171,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
|
|||
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
|
||||
FROM ${PLATFORM} AS final
|
||||
ENV ATTENTION=paged
|
||||
ENV USE_PREFIX_CACHING=0
|
||||
ENV CUDA_GRAPHS=0
|
||||
ENTRYPOINT ["text-generation-launcher"]
|
||||
CMD ["--json-output"]
|
||||
|
|
README.md
|
@ -13,7 +13,7 @@
|
|||
<img alt="Swagger API documentation" src="https://img.shields.io/badge/API-Swagger-informational">
|
||||
</a>
|
||||
|
||||
A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
|
||||
A Rust, Python and gRPC server for text generation inference. Used in production at [Hugging Face](https://huggingface.co)
|
||||
to power Hugging Chat, the Inference API and Inference Endpoint.
|
||||
|
||||
</div>
|
||||
|
@ -42,12 +42,15 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan
|
|||
- Tensor Parallelism for faster inference on multiple GPUs
|
||||
- Token streaming using Server-Sent Events (SSE)
|
||||
- Continuous batching of incoming requests for increased total throughput
|
||||
- [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) compatible with Open AI Chat Completion API
|
||||
- Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures
|
||||
- Quantization with :
|
||||
- [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
|
||||
- [GPT-Q](https://arxiv.org/abs/2210.17323)
|
||||
- [EETQ](https://github.com/NetEase-FuXi/EETQ)
|
||||
- [AWQ](https://github.com/casper-hansen/AutoAWQ)
|
||||
- [Marlin](https://github.com/IST-DASLab/marlin)
|
||||
- [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/)
|
||||
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
|
||||
- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
|
||||
- Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))
|
||||
|
@ -92,6 +95,29 @@ curl 127.0.0.1:8080/generate_stream \
|
|||
-H 'Content-Type: application/json'
|
||||
```
|
||||
|
||||
You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain Open AI Chat Completion API compatible responses.
|
||||
|
||||
```bash
|
||||
curl localhost:3000/v1/chat/completions \
|
||||
-X POST \
|
||||
-d '{
|
||||
"model": "tgi",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is deep learning?"
|
||||
}
|
||||
],
|
||||
"stream": true,
|
||||
"max_tokens": 20
|
||||
}' \
|
||||
-H 'Content-Type: application/json'
|
||||
```
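
For readers who prefer a typed client, here is a minimal Rust sketch of the same Messages API call. It is illustrative only: it assumes the `reqwest` crate (with its `blocking` and `json` features) plus `serde_json` as dependencies and a TGI server on `localhost:3000` as above, and it disables streaming so the whole response body can be read in one call.

```rust
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Same payload as the curl example above, with streaming turned off.
    let body = json!({
        "model": "tgi",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is deep learning?"}
        ],
        "stream": false,
        "max_tokens": 20
    });

    // Assumes a TGI server listening on localhost:3000, as in the curl example.
    let response = reqwest::blocking::Client::new()
        .post("http://localhost:3000/v1/chat/completions")
        .json(&body)
        .send()?
        .text()?;

    println!("{response}");
    Ok(())
}
```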
|
||||
|
||||
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
|
||||
|
||||
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.2.0-rocm --model-id $model` instead of the command above.
|
||||
|
@ -120,7 +146,7 @@ For example, if you want to serve the gated Llama V2 model variants:
|
|||
or with Docker:
|
||||
|
||||
```shell
|
||||
model=meta-llama/Llama-2-7b-chat-hf
|
||||
model=meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
token=<your cli READ token>
|
||||
|
||||
|
@ -163,6 +189,8 @@ overridden with the `--otlp-service-name` argument
|
|||
|
||||
![TGI architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/TGI.png)
|
||||
|
||||
Detailed blogpost by Adyen on TGI inner workings: [LLM inference at scale with TGI (Martin Iglesias Goyanes - Adyen, 2024)](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi)
|
||||
|
||||
### Local install
|
||||
|
||||
You can also opt to install `text-generation-inference` locally.
|
||||
|
@ -232,7 +260,7 @@ text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2
|
|||
|
||||
### Quantization
|
||||
|
||||
You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:
|
||||
You can also run pre-quantized weights (AWQ, GPTQ, Marlin) or on-the-fly quantize weights with bitsandbytes, EETQ, fp8, to reduce the VRAM requirement:
|
||||
|
||||
```shell
|
||||
text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantize
|
||||
|
@ -240,6 +268,8 @@ text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantiz
|
|||
|
||||
4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.
|
||||
|
||||
Read more about quantization in the [Quantization documentation](https://huggingface.co/docs/text-generation-inference/en/conceptual/quantization).
|
||||
|
||||
## Develop
|
||||
|
||||
```shell
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
{
|
||||
mkPoetryApplication,
|
||||
pkg-config,
|
||||
protobuf,
|
||||
openssl,
|
||||
}:
|
||||
|
||||
mkPoetryApplication {
|
||||
# name = "text-generation-server";
|
||||
|
||||
projectDir = ./server;
|
||||
|
||||
# nativeBuildInputs = [ pkg-config ];
|
||||
|
||||
# buildInputs = [ openssl.dev protobuf ];
|
||||
|
||||
}
|
|
@ -153,9 +153,12 @@ impl Client {
|
|||
}),
|
||||
// We truncate the input on the server side to be sure that it has the correct size
|
||||
truncate,
|
||||
// Most request will have that
|
||||
add_special_tokens: true,
|
||||
// Blocks and slots will be set on the server side if we use paged attention
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
prefix_len: 0,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 0.9,
|
||||
|
|
|
@ -221,6 +221,7 @@ impl Health for ShardedClient {
|
|||
chunks: vec![Chunk::Text("liveness".into()).into()],
|
||||
}),
|
||||
truncate: 10,
|
||||
add_special_tokens: true,
|
||||
prefill_logprobs: false,
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 1.0,
|
||||
|
@ -244,6 +245,7 @@ impl Health for ShardedClient {
|
|||
// Block 0 is reserved for health checks
|
||||
blocks: vec![0],
|
||||
slots: (0..16).collect(),
|
||||
prefix_len: 0,
|
||||
adapter_id: None,
|
||||
};
|
||||
let batch = Batch {
|
||||
|
|
|
@ -8,17 +8,18 @@ homepage.workspace = true
|
|||
[dependencies]
|
||||
async-trait = "0.1"
|
||||
async-stream = "0.3"
|
||||
clap = { version = "4.5", features = ["derive"] }
|
||||
cxx = "1.0"
|
||||
log = { version = "0.4", features = [] }
|
||||
text-generation-router = { path = "../../router" }
|
||||
tokenizers = { version = "0.19", features = ["hf-hub"] }
|
||||
tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tokio-stream = "0.1.15"
|
||||
clap = { version = "4.5", features = ["derive"] }
|
||||
thiserror = "1.0.62"
|
||||
tracing = "0.1"
|
||||
tracing-opentelemetry = "0.24"
|
||||
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
|
||||
log = { version = "0.4", features = [] }
|
||||
parking_lot = "0.12"
|
||||
|
||||
[build-dependencies]
|
||||
cmake = "0.1"
|
||||
|
|
|
@ -3,7 +3,7 @@ ARG OMPI_VERSION="4.1.6"
|
|||
|
||||
# Build dependencies resolver stage
|
||||
FROM lukemathwalker/cargo-chef:latest AS chef
|
||||
WORKDIR /usr/src/text-generation-inference
|
||||
WORKDIR /usr/src/text-generation-inference/backends/trtllm
|
||||
|
||||
FROM chef AS planner
|
||||
COPY . .
|
||||
|
@ -42,7 +42,7 @@ RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILE
|
|||
mkdir /usr/src/mpi && \
|
||||
tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
|
||||
cd /usr/src/mpi && \
|
||||
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --without-slurm && \
|
||||
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda && \
|
||||
make -j all && \
|
||||
make install && \
|
||||
rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
|
||||
|
@ -66,7 +66,7 @@ ENV PATH="/root/.cargo/bin:$PATH"
|
|||
RUN cargo install cargo-chef
|
||||
|
||||
# Cache dependencies
|
||||
COPY --from=planner /usr/src/text-generation-inference/recipe.json .
|
||||
COPY --from=planner /usr/src/text-generation-inference/backends/trtllm/recipe.json .
|
||||
RUN cargo chef cook --release --recipe-path recipe.json
|
||||
|
||||
# Build actual TGI
|
||||
|
@ -79,7 +79,8 @@ COPY . .
|
|||
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
|
||||
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
|
||||
RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
|
||||
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release --bin text-generation-backends-trtllm
|
||||
cd backends/trtllm && \
|
||||
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release
|
||||
|
||||
FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime
|
||||
WORKDIR /usr/local/tgi/bin
|
||||
|
|
|
@ -12,12 +12,13 @@ use cxx::UniquePtr;
|
|||
use log::{error, warn};
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time::{sleep, Instant};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tokio_stream::{Stream, StreamExt};
|
||||
use tracing::{instrument, span, Level};
|
||||
|
||||
// use tokio::sync::RwLock;
|
||||
use parking_lot::RwLock;
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::ValidationError::UnsupportedModality;
|
||||
use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError};
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
use clap::Parser;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use clap::Parser;
|
||||
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
||||
|
||||
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
|
||||
use text_generation_backends_trtllm::TensorRtLlmBackend;
|
||||
use text_generation_router::server;
|
||||
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
||||
|
||||
/// App Configuration
|
||||
#[derive(Parser, Debug)]
|
||||
|
@ -160,6 +158,8 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
|
|||
messages_api_enabled,
|
||||
true,
|
||||
max_client_batch_size,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
|
|
|
@ -33,9 +33,16 @@ rand = "0.8.5"
|
|||
reqwest = { version = "0.11.20", features = [] }
|
||||
serde = "1.0.188"
|
||||
serde_json = "1.0.107"
|
||||
slotmap = "1.0.7"
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { workspace = true}
|
||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tokenizers = { workspace = true }
|
||||
tokio = { version = "1.32.0", features = [
|
||||
"rt",
|
||||
"rt-multi-thread",
|
||||
"parking_lot",
|
||||
"signal",
|
||||
"sync",
|
||||
] }
|
||||
tokio-stream = "0.1.14"
|
||||
tower-http = { version = "0.5.1", features = ["cors"] }
|
||||
tracing = "0.1.37"
|
||||
|
@ -43,9 +50,11 @@ tracing-opentelemetry = "0.21.0"
|
|||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||
utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
||||
minijinja = { version = "2.0.2" }
|
||||
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
||||
"opentelemetry-otlp",
|
||||
] }
|
||||
minijinja = { workspace = true }
|
||||
minijinja-contrib = { workspace = true }
|
||||
futures-util = "0.3.30"
|
||||
regex = "1.10.3"
|
||||
once_cell = "1.19.0"
|
||||
|
@ -59,8 +68,16 @@ tower = "^0.4"
|
|||
tonic-build = "0.10.1"
|
||||
prost-build = "0.12.1"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.3"
|
||||
itertools = "0.13"
|
||||
|
||||
[features]
|
||||
default = ["ngrok"]
|
||||
ngrok = ["text-generation-router/ngrok"]
|
||||
google = ["text-generation-router/google"]
|
||||
kserve = ["text-generation-router/kserve"]
|
||||
|
||||
[[bench]]
|
||||
name = "prefix_cache"
|
||||
harness = false
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
use rand::Rng;
|
||||
|
||||
use text_generation_router_v3::block_allocator::Allocator;
|
||||
use text_generation_router_v3::radix::RadixAllocator;
|
||||
|
||||
fn prefix_cache_benchmark(c: &mut Criterion) {
|
||||
// let prefixes: Vec<Vec<u32>> = (0..8192)
|
||||
// .chunks(256)
|
||||
// .into_iter()
|
||||
// .map(|c| c.collect())
|
||||
// .collect();
|
||||
|
||||
let mut cache = RadixAllocator::new(1, 262144, None);
|
||||
|
||||
c.bench_function("Radix allocator", |b| {
|
||||
b.iter_batched(
|
||||
|| {
|
||||
//prefixes
|
||||
// .choose_multiple(&mut rand::thread_rng(), 5)
|
||||
// .fold(Vec::new(), |mut v, s| {
|
||||
// v.extend(s);
|
||||
// v
|
||||
// })
|
||||
|
||||
(0..7936)
|
||||
.map(|_| rand::thread_rng().gen_range(0..1024))
|
||||
.collect::<Vec<u32>>()
|
||||
},
|
||||
|prefill| {
|
||||
let alloc = cache.allocate(
|
||||
prefill.len() as u32 + 13,
|
||||
Some(Arc::new(black_box(prefill))),
|
||||
);
|
||||
if let Some(alloc) = alloc {
|
||||
cache.free(alloc.blocks.clone(), alloc.allocation_id);
|
||||
}
|
||||
},
|
||||
criterion::BatchSize::SmallInput,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(benches, prefix_cache_benchmark);
|
||||
criterion_main!(benches);
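
This benchmark is registered as the `prefix_cache` bench target in the router Cargo.toml above (`harness = false`, so Criterion supplies the entry point through `criterion_main!`). Each iteration builds a fresh ~7.9k-token random prompt, allocates it through `RadixAllocator`, and immediately frees the allocation, so it mostly exercises radix-tree lookup and insertion cost; it should run with a plain `cargo bench` from that crate (an assumption about the workspace layout, not something stated in the diff).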
|
|
@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
|
|||
use std::sync::Arc;
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::ValidGenerateRequest;
|
||||
use text_generation_router::{FinishReason, PrefillToken, Token};
|
||||
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify};
|
||||
use tokio::time::Instant;
|
||||
|
@ -35,16 +35,20 @@ impl BackendV3 {
|
|||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
) -> Self {
|
||||
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
|
||||
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let block_size = if flashdecoding { 256 } else { 16 };
|
||||
let prefix_caching =
|
||||
std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
|
||||
let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
|
||||
let attention: String = std::env::var("ATTENTION").expect("attention env var");
|
||||
|
||||
let attention: Attention = attention
|
||||
.parse()
|
||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
|
||||
let block_size = attention.block_size();
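
`Attention` here is imported from `text_generation_router` and its definition is not part of this diff. The sketch below is a hypothetical stand-in that only shows the shape this code relies on (a `FromStr` impl and a `block_size()` accessor); the variant names and the 16/256 block sizes are assumptions inferred from the removed `FLASH_DECODING` branch above and the `ATTENTION=paged` default set in the Dockerfiles.

```rust
use std::str::FromStr;

// Hypothetical stand-in for text_generation_router::Attention; variant names
// and block sizes are assumptions, not the actual definition.
#[derive(Debug, Clone, Copy)]
pub enum Attention {
    Paged,         // classic paged attention, small blocks
    FlashDecoding, // flash-decoding kernels, larger blocks
}

impl Attention {
    pub fn block_size(&self) -> u32 {
        match self {
            Attention::Paged => 16,
            Attention::FlashDecoding => 256,
        }
    }
}

impl FromStr for Attention {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "paged" => Ok(Attention::Paged),
            "flashdecoding" => Ok(Attention::FlashDecoding),
            other => Err(format!("unknown attention: {other}")),
        }
    }
}
```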
|
||||
|
||||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
|
@ -164,11 +168,14 @@ pub(crate) async fn batching_task(
|
|||
None
|
||||
} else {
|
||||
// Minimum batch size
|
||||
// TODO: temporarily disable to avoid incorrect deallocation +
|
||||
// reallocation when using prefix caching.
|
||||
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||
};
|
||||
|
||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
|
||||
let max_size =
|
||||
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
|
||||
|
||||
// Try to get a new batch
|
||||
if let Some((mut new_entries, new_batch, span)) = queue
|
||||
|
|
|
@ -1,21 +1,31 @@
|
|||
use std::cmp::min;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
use crate::radix::RadixAllocator;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct BlockAllocation {
|
||||
pub struct BlockAllocation {
|
||||
pub allocation_id: u64,
|
||||
pub blocks: Vec<u32>,
|
||||
pub slots: Vec<u32>,
|
||||
block_allocator: BlockAllocator,
|
||||
|
||||
/// Prefix that was cached and for which the KV does not have to
|
||||
/// be recomputed.
|
||||
pub prefix_len: u32,
|
||||
|
||||
pub(crate) block_allocator: Option<BlockAllocator>,
|
||||
}
|
||||
|
||||
impl Drop for BlockAllocation {
|
||||
fn drop(&mut self) {
|
||||
self.block_allocator.free(self.blocks.clone())
|
||||
if let Some(block_allocator) = self.block_allocator.as_mut() {
|
||||
block_allocator.free(self.blocks.clone(), self.allocation_id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct BlockAllocator {
|
||||
pub struct BlockAllocator {
|
||||
/// Channel to communicate with the background task
|
||||
block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
|
||||
}
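
As a usage sketch (illustrative code, not part of this commit; it assumes it lives in the same crate since `allocate` is `pub(crate)`): the caller now passes the prefill token ids along with the token count, reads `prefix_len` to learn how much of the prompt was served from cache, and simply keeps the returned `BlockAllocation` alive; dropping it triggers the `Drop` impl above, which returns the blocks by `allocation_id`.

```rust
use std::sync::Arc;

// Hypothetical helper, for illustration only.
async fn reserve(
    allocator: &BlockAllocator,
    tokens: u32,
    prefill_tokens: Arc<Vec<u32>>,
) -> Option<BlockAllocation> {
    // The allocator may satisfy part of the request from its prefix cache.
    let allocation = allocator.allocate(tokens, Some(prefill_tokens)).await?;
    tracing::debug!(
        "reserved {} blocks, {} prefix tokens already cached",
        allocation.blocks.len(),
        allocation.prefix_len
    );
    // Keep the allocation alive while the blocks are in use; dropping it
    // frees them back to the allocator (see the Drop impl above).
    Some(allocation)
}
```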
|
||||
|
@ -24,6 +34,7 @@ impl BlockAllocator {
|
|||
pub(crate) fn new(
|
||||
max_batch_total_tokens: u32,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
) -> Self {
|
||||
// Create channel
|
||||
|
@ -33,6 +44,7 @@ impl BlockAllocator {
|
|||
tokio::spawn(block_allocator_task(
|
||||
max_batch_total_tokens / block_size,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
receiver,
|
||||
));
|
||||
|
@ -42,28 +54,32 @@ impl BlockAllocator {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
|
||||
pub(crate) async fn allocate(
|
||||
&self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation> {
|
||||
let (response_sender, response_receiver) = oneshot::channel();
|
||||
self.block_allocator
|
||||
.send(BlockAllocatorCommand::Allocate {
|
||||
tokens,
|
||||
prefill_tokens,
|
||||
response_sender,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
response_receiver
|
||||
.await
|
||||
.unwrap()
|
||||
.map(|(blocks, slots)| BlockAllocation {
|
||||
blocks,
|
||||
slots,
|
||||
block_allocator: self.clone(),
|
||||
})
|
||||
response_receiver.await.unwrap().map(|mut allocation| {
|
||||
allocation.block_allocator = Some(self.clone());
|
||||
allocation
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn free(&self, blocks: Vec<u32>) {
|
||||
pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {
|
||||
self.block_allocator
|
||||
.send(BlockAllocatorCommand::Free { blocks })
|
||||
.send(BlockAllocatorCommand::Free {
|
||||
allocation_id,
|
||||
blocks,
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
@ -71,54 +87,29 @@ impl BlockAllocator {
|
|||
async fn block_allocator_task(
|
||||
blocks: u32,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
|
||||
) {
|
||||
// Block 0 is reserved for health checks
|
||||
let mut free_blocks: Vec<u32> = (1..blocks).collect();
|
||||
let mut allocator: Box<dyn Allocator + Send> = if prefix_caching {
|
||||
Box::new(RadixAllocator::new(block_size, blocks, window_size))
|
||||
} else {
|
||||
Box::new(SimpleAllocator::new(blocks, block_size, window_size))
|
||||
};
|
||||
while let Some(cmd) = receiver.recv().await {
|
||||
match cmd {
|
||||
BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
|
||||
BlockAllocatorCommand::Free {
|
||||
blocks,
|
||||
allocation_id,
|
||||
} => allocator.free(blocks, allocation_id),
|
||||
BlockAllocatorCommand::Allocate {
|
||||
tokens,
|
||||
prefill_tokens,
|
||||
response_sender,
|
||||
} => {
|
||||
// Apply window size
|
||||
let (required_blocks, repeats) = {
|
||||
let (tokens, repeats) = match window_size {
|
||||
None => (tokens, 1),
|
||||
Some(window_size) => {
|
||||
let repeats = (tokens + window_size - 1) / window_size;
|
||||
let tokens = min(tokens, window_size);
|
||||
(tokens, repeats as usize)
|
||||
}
|
||||
};
|
||||
// Pad to a multiple of block size
|
||||
let required_blocks = (tokens + block_size - 1) / block_size;
|
||||
(required_blocks, repeats)
|
||||
};
|
||||
|
||||
let tokens = tokens as usize;
|
||||
let allocation = if required_blocks > free_blocks.len() as u32 {
|
||||
None
|
||||
} else {
|
||||
let blocks =
|
||||
free_blocks.split_off(free_blocks.len() - required_blocks as usize);
|
||||
let mut slots = Vec::with_capacity(
|
||||
(required_blocks * block_size * repeats as u32) as usize,
|
||||
);
|
||||
|
||||
'slots: for block_id in blocks.repeat(repeats).iter() {
|
||||
for s in (block_id * block_size)..((block_id + 1) * block_size) {
|
||||
slots.push(s);
|
||||
if slots.len() == tokens {
|
||||
break 'slots;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some((blocks, slots))
|
||||
};
|
||||
response_sender.send(allocation).unwrap();
|
||||
response_sender
|
||||
.send(allocator.allocate(tokens, prefill_tokens))
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -128,9 +119,91 @@ async fn block_allocator_task(
|
|||
enum BlockAllocatorCommand {
|
||||
Free {
|
||||
blocks: Vec<u32>,
|
||||
allocation_id: u64,
|
||||
},
|
||||
Allocate {
|
||||
tokens: u32,
|
||||
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
response_sender: oneshot::Sender<Option<BlockAllocation>>,
|
||||
},
|
||||
}
|
||||
|
||||
pub trait Allocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation>;
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
|
||||
}
|
||||
pub struct SimpleAllocator {
|
||||
free_blocks: Vec<u32>,
|
||||
block_size: u32,
|
||||
window_size: Option<u32>,
|
||||
}
|
||||
|
||||
impl SimpleAllocator {
|
||||
fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
|
||||
SimpleAllocator {
|
||||
block_size,
|
||||
// Block 0 is reserved for health checks
|
||||
free_blocks: (1..blocks).collect(),
|
||||
window_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Allocator for SimpleAllocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
_prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation> {
|
||||
// Apply window size
|
||||
let (required_blocks, repeats) = {
|
||||
let (tokens, repeats) = match self.window_size {
|
||||
None => (tokens, 1),
|
||||
Some(window_size) => {
|
||||
let repeats = (tokens + window_size - 1) / window_size;
|
||||
let tokens = core::cmp::min(tokens, window_size);
|
||||
(tokens, repeats as usize)
|
||||
}
|
||||
};
|
||||
// Pad to a multiple of block size
|
||||
let required_blocks = (tokens + self.block_size - 1) / self.block_size;
|
||||
(required_blocks, repeats)
|
||||
};
|
||||
|
||||
let tokens = tokens as usize;
|
||||
if required_blocks > self.free_blocks.len() as u32 {
|
||||
None
|
||||
} else {
|
||||
let blocks = self
|
||||
.free_blocks
|
||||
.split_off(self.free_blocks.len() - required_blocks as usize);
|
||||
let mut slots =
|
||||
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
|
||||
|
||||
'slots: for block_id in blocks.repeat(repeats).iter() {
|
||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
||||
slots.push(s);
|
||||
if slots.len() == tokens {
|
||||
break 'slots;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(BlockAllocation {
|
||||
allocation_id: 0,
|
||||
blocks,
|
||||
slots,
|
||||
prefix_len: 0,
|
||||
block_allocator: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {
|
||||
self.free_blocks.extend(blocks)
|
||||
}
|
||||
}
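
To make the rounding above concrete, here is a small illustrative test (not part of the commit) against `SimpleAllocator`: 40 tokens with a block size of 16 and no window round up to ceil(40 / 16) = 3 blocks, while exactly 40 slots are emitted and nothing is reported as cached.

```rust
#[cfg(test)]
mod simple_allocator_math {
    use super::*;

    #[test]
    fn rounds_tokens_up_to_whole_blocks() {
        // 100 blocks total (block 0 stays reserved), block_size 16, no window.
        let mut allocator = SimpleAllocator::new(100, 16, None);
        let allocation = allocator.allocate(40, None).expect("enough free blocks");

        assert_eq!(allocation.blocks.len(), 3); // ceil(40 / 16)
        assert_eq!(allocation.slots.len(), 40); // one slot per requested token
        assert_eq!(allocation.prefix_len, 0);   // SimpleAllocator never caches
    }
}
```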
|
||||
|
|
|
@ -149,6 +149,7 @@ impl Client {
|
|||
requests.push(Request {
|
||||
id: 0,
|
||||
inputs,
|
||||
add_special_tokens: true,
|
||||
input_chunks: Some(Input {
|
||||
chunks: input_chunks,
|
||||
}),
|
||||
|
@ -157,6 +158,7 @@ impl Client {
|
|||
// Blocks and slots will be set on the server side if we use paged attention
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
prefix_len: 0,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 0.9,
|
||||
|
|
|
@ -222,6 +222,7 @@ impl Health for ShardedClient {
|
|||
chunks: vec![Chunk::Text("liveness".into()).into()],
|
||||
}),
|
||||
truncate: 10,
|
||||
add_special_tokens: true,
|
||||
prefill_logprobs: false,
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 1.0,
|
||||
|
@ -245,6 +246,7 @@ impl Health for ShardedClient {
|
|||
// Block 0 is reserved for health checks
|
||||
blocks: vec![0],
|
||||
slots: (0..16).collect(),
|
||||
prefix_len: 0,
|
||||
adapter_id: None,
|
||||
};
|
||||
let batch = Batch {
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
mod backend;
|
||||
mod block_allocator;
|
||||
pub mod block_allocator;
|
||||
mod client;
|
||||
mod queue;
|
||||
pub mod radix;
|
||||
|
||||
use crate::client::{ClientError, ShardedClient};
|
||||
pub(crate) use backend::BackendV3;
|
||||
|
|
|
@ -150,6 +150,14 @@ async fn main() -> Result<(), RouterError> {
|
|||
}
|
||||
}
|
||||
|
||||
if let Some(max_batch_size) = max_batch_size {
|
||||
if max_batch_size == 0 {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_batch_size` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let (backend, _backend_info) = connect_backend(
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
|
|
|
@ -46,6 +46,7 @@ impl Queue {
|
|||
pub(crate) fn new(
|
||||
requires_padding: bool,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
|
@ -57,6 +58,7 @@ impl Queue {
|
|||
tokio::spawn(queue_task(
|
||||
requires_padding,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
|
@ -109,6 +111,7 @@ impl Queue {
|
|||
async fn queue_task(
|
||||
requires_padding: bool,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
|
@ -117,6 +120,7 @@ async fn queue_task(
|
|||
let mut state = State::new(
|
||||
requires_padding,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
|
@ -176,12 +180,19 @@ impl State {
|
|||
fn new(
|
||||
requires_padding: bool,
|
||||
block_size: u32,
|
||||
prefix_caching: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
) -> Self {
|
||||
let block_allocator = (!requires_padding)
|
||||
.then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));
|
||||
let block_allocator = (!requires_padding).then(|| {
|
||||
BlockAllocator::new(
|
||||
max_batch_total_tokens,
|
||||
block_size,
|
||||
prefix_caching,
|
||||
window_size,
|
||||
)
|
||||
});
|
||||
|
||||
Self {
|
||||
entries: VecDeque::with_capacity(128),
|
||||
|
@ -226,25 +237,29 @@ impl State {
|
|||
}
|
||||
}
|
||||
|
||||
if let Some(max_size) = max_size {
|
||||
if max_size == 0 {
|
||||
tracing::debug!("No capacity");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// Pad prefill_token_budget to be a multiple of block size
|
||||
let prefill_token_budget =
|
||||
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
||||
next_batch_span.follows_from(&Span::current());
|
||||
|
||||
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
||||
let mut batch_entries =
|
||||
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
|
||||
next_batch_span.follows_from(Span::current());
|
||||
|
||||
let mut batch = Vec::with_capacity(self.entries.len());
|
||||
let mut max_input_length = 0;
|
||||
let mut prefill_tokens: u32 = 0;
|
||||
let mut decode_tokens: u32 = 0;
|
||||
let mut max_blocks = 0;
|
||||
|
||||
// Pop entries starting from the front of the queue
|
||||
'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
|
||||
'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {
|
||||
// Filter entries where the response receiver was dropped (== entries where the request
|
||||
// was dropped by the client)
|
||||
if entry.response_tx.is_closed() {
|
||||
|
@ -258,7 +273,7 @@ impl State {
|
|||
// We pad to max input length in the Python shards
|
||||
// We need to take these padding tokens into the equation
|
||||
max_input_length = max_input_length.max(entry.request.input_length);
|
||||
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;
|
||||
prefill_tokens = (batch.len() + 1) as u32 * max_input_length;
|
||||
|
||||
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
|
||||
let total_tokens = prefill_tokens + decode_tokens + self.speculate;
|
||||
|
@ -272,7 +287,7 @@ impl State {
|
|||
}
|
||||
None
|
||||
}
|
||||
Some(block_allocator) => {
|
||||
Some(_block_allocator) => {
|
||||
prefill_tokens += entry.request.input_length;
|
||||
let max_new_tokens = match self.window_size {
|
||||
None => entry.request.stopping_parameters.max_new_tokens,
|
||||
|
@ -298,23 +313,67 @@ impl State {
|
|||
+ self.speculate
|
||||
- 1;
|
||||
|
||||
match block_allocator.allocate(tokens).await {
|
||||
None => {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: not enough free blocks");
|
||||
self.entries.push_front((id, entry));
|
||||
break 'entry_loop;
|
||||
}
|
||||
Some(block_allocation) => {
|
||||
tracing::debug!("Allocation: {block_allocation:?}");
|
||||
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
||||
Some(block_allocation)
|
||||
}
|
||||
}
|
||||
// If users wants the prefill logprobs, we cannot reuse the cache.
|
||||
// So no input_ids for the radix tree.
|
||||
let input_ids = if entry.request.decoder_input_details {
|
||||
None
|
||||
} else {
|
||||
entry.request.input_ids.clone()
|
||||
};
|
||||
|
||||
Some((tokens, input_ids))
|
||||
}
|
||||
};
|
||||
batch.push((id, entry, block_allocation));
|
||||
if Some(batch.len()) == max_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Empty batch
|
||||
if batch.is_empty() {
|
||||
tracing::debug!("Filterered out all entries");
|
||||
return None;
|
||||
}
|
||||
|
||||
// XXX We haven't allocated yet, so we're allowed to ditch the results.
|
||||
// Check if our batch is big enough
|
||||
if let Some(min_size) = min_size {
|
||||
// Batch is too small
|
||||
if batch.len() < min_size {
|
||||
// Add back entries to the queue in the correct order
|
||||
for (id, entry, _) in batch.into_iter().rev() {
|
||||
self.entries.push_front((id, entry));
|
||||
}
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
||||
let mut batch_entries =
|
||||
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
|
||||
|
||||
for (id, mut entry, block_allocation) in batch {
|
||||
let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
|
||||
(block_allocation, &self.block_allocator)
|
||||
{
|
||||
match block_allocator.allocate(tokens, input_ids).await {
|
||||
None => {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: not enough free blocks");
|
||||
self.entries.push_front((id, entry));
|
||||
break;
|
||||
}
|
||||
Some(block_allocation) => {
|
||||
tracing::debug!("Allocation: {block_allocation:?}");
|
||||
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
||||
Some(block_allocation)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
tracing::debug!("Accepting entry");
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
|
@ -324,11 +383,12 @@ impl State {
|
|||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
|
||||
let (blocks, slots) = match &block_allocation {
|
||||
None => (Vec::new(), Vec::new()),
|
||||
let (blocks, slots, prefix_len) = match &block_allocation {
|
||||
None => (Vec::new(), Vec::new(), 0),
|
||||
Some(block_allocation) => (
|
||||
block_allocation.blocks.clone(),
|
||||
block_allocation.slots.clone(),
|
||||
block_allocation.prefix_len,
|
||||
),
|
||||
};
|
||||
|
||||
|
@ -356,6 +416,7 @@ impl State {
|
|||
}),
|
||||
inputs: entry.request.inputs.chunks_to_string(),
|
||||
truncate: entry.request.truncate,
|
||||
add_special_tokens: entry.request.add_special_tokens,
|
||||
parameters: Some(NextTokenChooserParameters::from(
|
||||
entry.request.parameters.clone(),
|
||||
)),
|
||||
|
@ -365,38 +426,13 @@ impl State {
|
|||
top_n_tokens: entry.request.top_n_tokens,
|
||||
blocks,
|
||||
slots,
|
||||
prefix_len,
|
||||
adapter_id: entry.request.adapter_id.clone(),
|
||||
});
|
||||
// Set batch_time
|
||||
entry.batch_time = Some(Instant::now());
|
||||
// Insert in batch_entries IntMap
|
||||
batch_entries.insert(id, entry);
|
||||
|
||||
// Check if max_size
|
||||
if Some(batch_requests.len()) == max_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Empty batch
|
||||
if batch_requests.is_empty() {
|
||||
tracing::debug!("Filterered out all entries");
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check if our batch is big enough
|
||||
if let Some(min_size) = min_size {
|
||||
// Batch is too small
|
||||
if batch_requests.len() < min_size {
|
||||
// Add back entries to the queue in the correct order
|
||||
for r in batch_requests.into_iter().rev() {
|
||||
let id = r.id;
|
||||
let entry = batch_entries.remove(&id).unwrap();
|
||||
self.entries.push_front((id, entry));
|
||||
}
|
||||
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// Final batch size
|
||||
|
@ -473,6 +509,8 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use tracing::info_span;
|
||||
|
||||
|
@ -485,7 +523,9 @@ mod tests {
|
|||
let entry = Entry {
|
||||
request: ValidGenerateRequest {
|
||||
inputs: vec![],
|
||||
input_ids: Some(Arc::new(vec![])),
|
||||
input_length: 0,
|
||||
add_special_tokens: true,
|
||||
truncate: 0,
|
||||
decoder_input_details: false,
|
||||
parameters: ValidParameters {
|
||||
|
@ -520,7 +560,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_append() {
|
||||
let mut state = State::new(false, 1, None, 0, 16);
|
||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||
let (entry, _guard) = default_entry();
|
||||
|
||||
assert_eq!(state.next_id, 0);
|
||||
|
@ -536,7 +576,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_empty() {
|
||||
let mut state = State::new(false, 1, None, 0, 16);
|
||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||
|
||||
assert!(state.next_batch(None, None, 1, 1).await.is_none());
|
||||
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||
|
@ -544,7 +584,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_min_size() {
|
||||
let mut state = State::new(false, 1, None, 0, 16);
|
||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
|
@ -576,7 +616,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_max_size() {
|
||||
let mut state = State::new(false, 1, None, 0, 16);
|
||||
let mut state = State::new(false, 1, false, None, 0, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
|
@ -596,7 +636,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_next_batch_token_budget() {
|
||||
let mut state = State::new(false, 1, None, 0, 2);
|
||||
let mut state = State::new(false, 1, false, None, 0, 2);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
state.append(entry1);
|
||||
|
@ -629,14 +669,14 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_append() {
|
||||
let queue = Queue::new(false, 1, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let (entry, _guard) = default_entry();
|
||||
queue.append(entry);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_empty() {
|
||||
let queue = Queue::new(false, 1, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
|
||||
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
|
||||
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
|
||||
|
@ -644,7 +684,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_min_size() {
|
||||
let queue = Queue::new(false, 1, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
|
@ -677,7 +717,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_max_size() {
|
||||
let queue = Queue::new(false, 1, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
|
@ -693,7 +733,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_budget() {
|
||||
let queue = Queue::new(false, 1, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
|
@ -718,7 +758,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_token_speculate() {
|
||||
let queue = Queue::new(false, 1, None, 2, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 2, 16);
|
||||
let (entry1, _guard1) = default_entry();
|
||||
let (entry2, _guard2) = default_entry();
|
||||
queue.append(entry1);
|
||||
|
@ -737,7 +777,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_queue_next_batch_dropped_receiver() {
|
||||
let queue = Queue::new(false, 1, None, 0, 16);
|
||||
let queue = Queue::new(false, 1, false, None, 0, 16);
|
||||
let (entry, _) = default_entry();
|
||||
queue.append(entry);
|
||||
|
||||
|
|
|
@ -0,0 +1,850 @@
|
|||
use crate::block_allocator::{Allocator, BlockAllocation};
|
||||
use slotmap::{DefaultKey, SlotMap};
|
||||
use std::{
|
||||
collections::{BTreeSet, HashMap},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
pub struct RadixAllocator {
|
||||
allocation_id: u64,
|
||||
|
||||
allocations: HashMap<u64, RadixAllocation>,
|
||||
|
||||
cache_blocks: RadixTrie,
|
||||
|
||||
/// Blocks that are immediately available for allocation.
|
||||
free_blocks: Vec<u32>,
|
||||
|
||||
#[allow(dead_code)]
|
||||
// This isn't used because the prefix need to match without the windowing
|
||||
// mecanism. This at worst is overallocating, not necessarily being wrong.
|
||||
window_size: Option<u32>,
|
||||
|
||||
block_size: u32,
|
||||
}
|
||||
|
||||
impl RadixAllocator {
|
||||
pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
|
||||
RadixAllocator {
|
||||
allocation_id: 0,
|
||||
allocations: HashMap::new(),
|
||||
cache_blocks: RadixTrie::new(block_size as usize),
|
||||
|
||||
// Block 0 is reserved for health checks.
|
||||
free_blocks: (1..n_blocks).collect(),
|
||||
window_size,
|
||||
block_size,
|
||||
}
|
||||
}
|
||||
|
||||
fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option<Vec<u32>> {
|
||||
if self.free_blocks.len() < n_blocks_needed {
|
||||
// This is a bit annoying, we first extend the free list and then
|
||||
// split it off again below. This is because we need to put it on
|
||||
// the free list if we cannot allocate enough blocks. This is only
|
||||
// temporary, the trie needs to be able to report whether it can
|
||||
// allocate the requested amount. Just not implemented yet.
|
||||
self.free_blocks.extend(
|
||||
self.cache_blocks
|
||||
.evict(n_blocks_needed - self.free_blocks.len()),
|
||||
);
|
||||
}
|
||||
|
||||
if self.free_blocks.len() >= n_blocks_needed {
|
||||
Some(
|
||||
self.free_blocks
|
||||
.split_off(self.free_blocks.len() - n_blocks_needed),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Allocator trait
|
||||
impl Allocator for RadixAllocator {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
tokens: u32,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
) -> Option<BlockAllocation> {
|
||||
let mut blocks = vec![];
|
||||
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
|
||||
let node_id = self
|
||||
.cache_blocks
|
||||
.find(prefill_tokens.as_slice(), &mut blocks);
|
||||
node_id
|
||||
} else {
|
||||
self.cache_blocks.root_id()
|
||||
};
|
||||
|
||||
// Even if this allocation fails below, we need to increase the
|
||||
// refcount to ensure that the prefix that was found is not evicted.
|
||||
self.cache_blocks
|
||||
.incref(prefix_node)
|
||||
.expect("Failed to increment refcount");
|
||||
|
||||
let prefix_len = blocks.len() * self.block_size as usize;
|
||||
let suffix_len = tokens - prefix_len as u32;
|
||||
|
||||
let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
|
||||
|
||||
tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
|
||||
|
||||
match self.alloc_or_reclaim(suffix_blocks as usize) {
|
||||
Some(suffix_blocks) => blocks.extend(suffix_blocks),
|
||||
None => {
|
||||
self.cache_blocks
|
||||
.decref(prefix_node)
|
||||
.expect("Failed to decrement refcount");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// 1:1 mapping of blocks and slots.
|
||||
let slots = if self.block_size == 1 {
|
||||
blocks.clone()
|
||||
} else {
|
||||
let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize);
|
||||
'slots: for block_id in &blocks {
|
||||
for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
|
||||
slots.push(s);
|
||||
if slots.len() as u32 == tokens {
|
||||
break 'slots;
|
||||
}
|
||||
}
|
||||
}
|
||||
slots
|
||||
};
|
||||
|
||||
let allocation = RadixAllocation {
|
||||
prefix_node,
|
||||
cached_prefix_len: prefix_len,
|
||||
prefill_tokens: prefill_tokens.clone(),
|
||||
};
|
||||
|
||||
tracing::debug!("Blocks {blocks:?}");
|
||||
|
||||
self.allocation_id += 1;
|
||||
self.allocations.insert(self.allocation_id, allocation);
|
||||
|
||||
Some(BlockAllocation {
|
||||
allocation_id: self.allocation_id,
|
||||
block_allocator: None,
|
||||
blocks,
|
||||
slots,
|
||||
prefix_len: prefix_len as u32,
|
||||
})
|
||||
}
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
|
||||
let allocation = match self.allocations.remove(&allocation_id) {
|
||||
Some(allocation) => allocation,
|
||||
None => unreachable!("Tried to free an unknown allocation."),
|
||||
};
|
||||
|
||||
self.cache_blocks
|
||||
.decref(allocation.prefix_node)
|
||||
.expect("Failed to decrement refcount");
|
||||
|
||||
if let Some(prefill_tokens) = allocation.prefill_tokens {
|
||||
let prefill_tokens = prefill_tokens.as_slice();
|
||||
|
||||
// If there are prefill tokens that did not come from the cache,
|
||||
// add them to the cache.
|
||||
if prefill_tokens.len() > allocation.cached_prefix_len {
|
||||
let aligned =
|
||||
(prefill_tokens.len() / self.block_size as usize) * self.block_size as usize;
|
||||
if aligned > 0 {
|
||||
let prefix_len = self
|
||||
.cache_blocks
|
||||
.insert(
|
||||
&prefill_tokens[..aligned],
|
||||
&blocks[..aligned / self.block_size as usize],
|
||||
)
|
||||
// Unwrap, failing is a programming error.
|
||||
.expect("Failed to store prefill tokens");
|
||||
// We can have a prefill with the following structure:
|
||||
//
|
||||
// |---| From the prefix cache.
|
||||
// A B C D E F G
|
||||
//|--------| Found in the trie during insertion.
|
||||
//
|
||||
// This means that while processing this request there was a
|
||||
// partially overlapping request that had A..=E in its
|
||||
// prefill. In this case we need to free the blocks D E.
|
||||
if prefix_len > allocation.cached_prefix_len {
|
||||
self.free_blocks.extend(
|
||||
&blocks[allocation.cached_prefix_len / self.block_size as usize
|
||||
..prefix_len / self.block_size as usize],
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Free non-prefill blocks.
|
||||
self.free_blocks
|
||||
.extend(&blocks[prefill_tokens.len() / self.block_size as usize..]);
|
||||
} else {
|
||||
self.free_blocks.extend(blocks);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct RadixAllocation {
|
||||
prefix_node: NodeId,
|
||||
cached_prefix_len: usize,
|
||||
prefill_tokens: Option<Arc<Vec<u32>>>,
|
||||
}
|
||||
|
||||
// Radix trie that is heavily inspired by radix attention from sglang.
|
||||
//
|
||||
// The trie is optimized for prefix caching:
|
||||
//
|
||||
// - A normal radix trie stores discrete values. In this radix trie,
|
||||
// inserting *abc* with value *xyz* will also enable lookup for
|
||||
// *a* (*x*) and *ab* (*xy*).
|
||||
// - As a result, every value is required to have the same length as
|
||||
// the key.
|
||||
// - We store additional information in each node, such as last access
|
||||
// time and a reference count.
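//
// Illustrative usage sketch (mirrors the unit tests at the bottom of this
// file; not part of the module's API surface):
//
//     let mut trie = RadixTrie::new(1);
//     // Key and blocks must have the same length (block_size == 1 here).
//     trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
//
//     let mut blocks = Vec::new();
//     // A shorter prefix is also found.
//     trie.find(&[0, 1], &mut blocks);
//     assert_eq!(blocks, vec![0, 1]);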
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum TrieError {
|
||||
InvalidNodeId,
|
||||
RefCountUnderflow,
|
||||
BlockTokenCountMismatch,
|
||||
}
|
||||
|
||||
pub type NodeId = DefaultKey;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RadixTrie {
|
||||
/// Identifier of the root node.
|
||||
root: DefaultKey,
|
||||
|
||||
/// Leaf node identifiers ordered by increasing recency.
|
||||
leaves: BTreeSet<(u64, NodeId)>,
|
||||
|
||||
/// All trie nodes.
|
||||
nodes: SlotMap<NodeId, TrieNode>,
|
||||
|
||||
/// Time as a monotonically increasing counter to avoid the system
|
||||
/// call that a real time lookup would require.
|
||||
time: u64,
|
||||
|
||||
/// All blocks need to be aligned with this
|
||||
block_size: usize,
|
||||
}
|
||||
|
||||
impl RadixTrie {
|
||||
/// Construct a new radix trie.
|
||||
pub fn new(block_size: usize) -> Self {
|
||||
let root = TrieNode::new(vec![], vec![], 0, None);
|
||||
let mut nodes = SlotMap::new();
|
||||
let root = nodes.insert(root);
|
||||
RadixTrie {
|
||||
leaves: BTreeSet::new(),
|
||||
nodes,
|
||||
root,
|
||||
time: 0,
|
||||
block_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the prefix of the given tokens.
|
||||
///
|
||||
/// The blocks corresponding to the part of the prefix that could be found
|
||||
/// are written to `blocks`. The number of blocks is in `0..=tokens.len()`.
|
||||
/// Returns the identifier of the trie node that contains the longest
|
||||
/// prefix. The node identifier can be used by callers to e.g. increase its
|
||||
/// reference count.
|
||||
///
|
||||
/// Using this method will update the access time of the traversed nodes.
|
||||
pub fn find(&mut self, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
||||
self.time += 1;
|
||||
self.find_(self.root, key, blocks)
|
||||
}
|
||||
|
||||
/// Find worker.
|
||||
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
||||
let node = &self.nodes[node_id];
|
||||
|
||||
if let Some(&child_id) = node.children.get(&key[0]) {
|
||||
self.update_access_time(child_id);
|
||||
let child = self.nodes.get(child_id).expect("Invalid child identifier");
|
||||
let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
|
||||
assert_eq!(shared_prefix_len % self.block_size, 0);
|
||||
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
|
||||
|
||||
let key = &key[shared_prefix_len..];
|
||||
if !key.is_empty() {
|
||||
node_id = self.find_(child_id, key, blocks);
|
||||
}
|
||||
}
|
||||
|
||||
node_id
|
||||
}
|
||||
|
||||
/// Decrease the reference count of a node.
|
||||
pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
|
||||
// We don't care about refcounting for root, since it will never
|
||||
// be evicted.
|
||||
if node_id == self.root {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let node = self
|
||||
.nodes
|
||||
.get_mut(node_id)
|
||||
.ok_or(TrieError::InvalidNodeId)?;
|
||||
if node.ref_count == 0 {
|
||||
return Err(TrieError::RefCountUnderflow);
|
||||
}
|
||||
|
||||
node.ref_count -= 1;
|
||||
if node.ref_count == 0 {
|
||||
assert!(
|
||||
node.children.is_empty(),
|
||||
"Nodes with children must have refcount > 0"
|
||||
);
|
||||
|
||||
self.leaves.insert((node.last_accessed, node_id));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Increase the reference count of a node.
|
||||
pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> {
|
||||
if node_id == self.root {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let node = self
|
||||
.nodes
|
||||
.get_mut(node_id)
|
||||
.ok_or(TrieError::InvalidNodeId)?;
|
||||
if node.ref_count == 0 {
|
||||
self.leaves.remove(&(node.last_accessed, node_id));
|
||||
}
|
||||
node.ref_count += 1;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Evict `n_blocks` from the trie.
|
||||
///
|
||||
/// Returns the evicted blocks. When the length is less than `n_blocks`,
|
||||
/// not enough blocks could be evicted.
|
||||
pub fn evict(&mut self, n_blocks: usize) -> Vec<u32> {
|
||||
// NOTE: we don't return Result here. If any of the unwrapping fails,
|
||||
// it's a programming error in the trie implementation, not a user
|
||||
// error caused by e.g. an invalid argument.
|
||||
|
||||
// TODO: add some bookkeeping in the future to check whether we can
|
||||
// evict n_blocks and return `None` if we can't. We are now needlessly
|
||||
// evicting prefixes from the cache in such a case.
|
||||
let mut evicted = Vec::new();
|
||||
|
||||
while let Some((last_access, node_id)) = self.leaves.pop_first() {
|
||||
let blocks_needed = n_blocks - evicted.len();
|
||||
|
||||
let node = self.nodes.get(node_id).expect("Leaf does not exist");
|
||||
assert_eq!(
|
||||
node.ref_count, 0,
|
||||
"Leaf must have refcount of 0, got {}",
|
||||
node.ref_count
|
||||
);
|
||||
|
||||
if blocks_needed >= node.blocks.len() {
|
||||
// We need to evict the whole node if we need more blocks than it has.
|
||||
let node = self.remove_node(node_id);
|
||||
evicted.extend(node.blocks);
|
||||
|
||||
if evicted.len() >= n_blocks {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
// The node has more blocks than needed, so we'll just remove
|
||||
// the required number of blocks and leave the remaining blocks
|
||||
// untouched.
|
||||
let node = self.nodes.get_mut(node_id).expect("Leaf does not exist");
|
||||
node.key.truncate(node.blocks.len() - blocks_needed);
|
||||
evicted.extend(node.blocks.split_off(node.blocks.len() - blocks_needed));
|
||||
self.leaves.insert((last_access, node_id));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
evicted
|
||||
}
|
||||
|
||||
/// Insert a prefill along with its blocks.
|
||||
///
|
||||
/// This method returns the length of the prefix that was already
|
||||
/// in the trie. E.g. if the length is 10, this means that for
|
||||
/// the first 10 elements of the tree **the blocks are not updated**.
|
||||
pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> {
|
||||
self.time += 1;
|
||||
let common = self.insert_(self.root, tokens, blocks)?;
|
||||
Ok(common)
|
||||
}
|
||||
|
||||
/// Insertion worker.
|
||||
fn insert_(
|
||||
&mut self,
|
||||
node_id: NodeId,
|
||||
tokens: &[u32],
|
||||
blocks: &[u32],
|
||||
) -> Result<usize, TrieError> {
|
||||
// TODO: in the future we may want to check that the blocks match for
|
||||
// the part of the prefix that is already in the trie to detect
|
||||
// mismatches.
|
||||
|
||||
if tokens.len() != blocks.len() * self.block_size {
|
||||
return Err(TrieError::BlockTokenCountMismatch);
|
||||
}
|
||||
|
||||
if let Some(&child_id) = self.nodes[node_id].children.get(&tokens[0]) {
|
||||
self.update_access_time(child_id);
|
||||
let child = self
|
||||
.nodes
|
||||
.get_mut(child_id)
|
||||
// Unwrap here, since failure is a bug.
|
||||
.expect("Child node does not exist");
|
||||
let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size);
|
||||
|
||||
// We are done, the prefix is already in the trie.
|
||||
if shared_prefix_len == tokens.len() || shared_prefix_len == 0 {
|
||||
return Ok(shared_prefix_len);
|
||||
}
|
||||
|
||||
// The node's prefix is a prefix of the insertion prefix.
|
||||
if shared_prefix_len == child.key.len() {
|
||||
return Ok(shared_prefix_len
|
||||
+ self.insert_(
|
||||
child_id,
|
||||
&tokens[shared_prefix_len..],
|
||||
&blocks[shared_prefix_len / self.block_size..],
|
||||
)?);
|
||||
}
|
||||
|
||||
// The node's prefix and the insertion prefix only match partially,
|
||||
// split the node to just contain the matching part. Then insert the
|
||||
// remainder of the prefix into the node again
|
||||
let child_id = self.split_node(child_id, shared_prefix_len);
|
||||
let key = &tokens[shared_prefix_len..];
|
||||
let blocks = &blocks[shared_prefix_len / self.block_size..];
|
||||
Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?)
|
||||
} else {
|
||||
self.add_node(node_id, tokens, blocks);
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId {
|
||||
// We have to make the current node a child to ensure that its
|
||||
// properties and node id stay the same.
|
||||
|
||||
// This function unwraps; an invalid node_id is a programming error.
|
||||
|
||||
let node = self
|
||||
.nodes
|
||||
.get_mut(node_id)
|
||||
.expect("Node to-be split does not exist");
|
||||
let mut parent_key = node.key.split_off(prefix_len);
|
||||
let mut parent_blocks = node.blocks.split_off(prefix_len);
|
||||
|
||||
// Move first part of the prefix to the parent. We swap to avoid
|
||||
// an allocation + copy for both splits of the key/blocks.
|
||||
std::mem::swap(&mut node.key, &mut parent_key);
|
||||
std::mem::swap(&mut node.blocks, &mut parent_blocks);
|
||||
|
||||
let node_key = node.key[0];
|
||||
|
||||
let grandparent_id = node.parent.expect("Node does not have a parent");
|
||||
let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
|
||||
self.add_node_to_parent(parent_id, node_key, node_id);
|
||||
|
||||
// Reborrow to make the borrow checker happy.
|
||||
let node = self
|
||||
.nodes
|
||||
.get_mut(node_id)
|
||||
.expect("Node to-be split does not exist");
|
||||
node.parent = Some(parent_id);
|
||||
|
||||
parent_id
|
||||
}
|
||||
|
||||
/// Create a node and add it to the parent.
|
||||
fn add_node(
|
||||
&mut self,
|
||||
parent_id: NodeId,
|
||||
key: impl Into<Vec<u32>>,
|
||||
blocks: impl Into<Vec<u32>>,
|
||||
) -> NodeId {
|
||||
let key = key.into();
|
||||
let blocks = blocks.into();
|
||||
let first = key[0];
|
||||
|
||||
let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
|
||||
let child_id = self.nodes.insert(child);
|
||||
|
||||
self.add_node_to_parent(parent_id, first, child_id);
|
||||
self.leaves.insert((self.time, child_id));
|
||||
|
||||
child_id
|
||||
}
|
||||
|
||||
/// Add a node to the parent.
|
||||
fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) {
|
||||
// Unwrap here, passing in an unknown id is a programming error.
|
||||
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
|
||||
if parent.children.insert(first, child_id).is_none() {
|
||||
// Only increase reference count if child does not replace another child.
|
||||
self.incref(parent_id)
|
||||
.expect("Failed to increase parent refcount");
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a node from the trie.
|
||||
fn remove_node(&mut self, node_id: NodeId) -> TrieNode {
|
||||
// Unwrap here, passing in an unknown id is a programming error.
|
||||
let node = self.nodes.remove(node_id).expect("Unknown node");
|
||||
assert!(
|
||||
node.children.is_empty(),
|
||||
"Tried to remove a node with {} children",
|
||||
node.children.len()
|
||||
);
|
||||
let parent_id = node.parent.expect("Attempted to remove root node");
|
||||
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
|
||||
parent.children.remove(&node.key[0]);
|
||||
self.decref(parent_id)
|
||||
.expect("Failed to decrease parent refcount");
|
||||
node
|
||||
}
|
||||
|
||||
fn update_access_time(&mut self, node_id: NodeId) {
|
||||
// Unwrap here, passing in an unknown id is a programming error.
|
||||
let node = self.nodes.get_mut(node_id).expect("Unknown node");
|
||||
|
||||
// Update the ordered leaves set if the node is a leaf.
|
||||
if self.leaves.remove(&(node.last_accessed, node_id)) {
|
||||
self.leaves.insert((self.time, node_id));
|
||||
}
|
||||
|
||||
node.last_accessed = self.time;
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[doc(hidden)]
|
||||
/// Print debugging output for the trie.
|
||||
///
|
||||
/// In contrast to `Debug`, this output is nicely formatted.
|
||||
pub fn print_debug(&self) {
|
||||
self.print_debug_(self.root, 0);
|
||||
}
|
||||
|
||||
fn print_debug_(&self, node_id: NodeId, indent: usize) {
|
||||
let node = &self.nodes[node_id];
|
||||
eprintln!(
|
||||
"{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}",
|
||||
" ".repeat(indent),
|
||||
node_id,
|
||||
node.key,
|
||||
node.blocks,
|
||||
node.ref_count,
|
||||
node.last_accessed,
|
||||
node.parent,
|
||||
node.children
|
||||
);
|
||||
for child_id in self.nodes[node_id].children.values() {
|
||||
self.print_debug_(*child_id, indent + 2);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn root_id(&self) -> DefaultKey {
|
||||
self.root
|
||||
}
|
||||
}
|
||||
|
||||
/// Trie node.
|
||||
#[derive(Debug)]
|
||||
struct TrieNode {
|
||||
blocks: Vec<u32>,
|
||||
children: HashMap<u32, NodeId>,
|
||||
key: Vec<u32>,
|
||||
last_accessed: u64,
|
||||
parent: Option<NodeId>,
|
||||
ref_count: usize,
|
||||
}
|
||||
|
||||
impl TrieNode {
|
||||
fn new(key: Vec<u32>, blocks: Vec<u32>, last_accessed: u64, parent: Option<NodeId>) -> Self {
|
||||
TrieNode {
|
||||
children: HashMap::new(),
|
||||
key,
|
||||
blocks,
|
||||
last_accessed,
|
||||
parent,
|
||||
ref_count: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize {
|
||||
let full = left.iter().zip(right).take_while(|(a, b)| a == b).count();
|
||||
// NOTE: this is the case because the child node was chosen based on
|
||||
// matching the first character of the key/prefix.
|
||||
assert!(full > 0, "Prefixes must at least share 1 token");
|
||||
(full / block_size) * block_size
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn allocator_block_size() {
|
||||
let mut cache = RadixAllocator::new(2, 12, None);
|
||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
|
||||
assert_eq!(allocation.prefix_len, 0);
|
||||
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||
|
||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
|
||||
assert_eq!(allocation.prefix_len, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_block_size_non_aligned() {
|
||||
let mut cache = RadixAllocator::new(2, 12, None);
|
||||
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
|
||||
assert_eq!(allocation.prefix_len, 0);
|
||||
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||
|
||||
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
|
||||
assert_eq!(allocation.prefix_len, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_reuses_prefixes() {
|
||||
let mut cache = RadixAllocator::new(1, 12, None);
|
||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
||||
assert_eq!(allocation.blocks, allocation.slots);
|
||||
assert_eq!(allocation.prefix_len, 0);
|
||||
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||
|
||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
||||
assert_eq!(allocation.prefix_len, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_collects_older_prefixes_first() {
|
||||
let mut cache = RadixAllocator::new(1, 7, None);
|
||||
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);
|
||||
assert_eq!(allocation1.prefix_len, 0);
|
||||
|
||||
let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap();
|
||||
assert_eq!(allocation2.blocks, vec![1, 2]);
|
||||
assert_eq!(allocation2.prefix_len, 0);
|
||||
|
||||
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
|
||||
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
|
||||
|
||||
// We should get the blocks of the first allocation, since they are more recent.
|
||||
let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap();
|
||||
assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]);
|
||||
assert_eq!(allocation3.prefix_len, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_frees_fully_overlapping_prefills() {
|
||||
let mut cache = RadixAllocator::new(1, 10, None);
|
||||
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
|
||||
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
|
||||
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
|
||||
|
||||
let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||
assert_eq!(allocation3.prefix_len, 4);
|
||||
|
||||
// 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks.
|
||||
assert_eq!(cache.free_blocks.len(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocator_frees_partially_overlapping_prefills() {
|
||||
let mut cache = RadixAllocator::new(1, 20, None);
|
||||
let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap();
|
||||
assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);
|
||||
assert_eq!(allocation1.prefix_len, 0);
|
||||
|
||||
cache.free(allocation1.blocks.clone(), allocation1.allocation_id);
|
||||
|
||||
let allocation2 = cache
|
||||
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
|
||||
.unwrap();
|
||||
assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]);
|
||||
assert_eq!(allocation2.prefix_len, 2);
|
||||
|
||||
let allocation3 = cache
|
||||
.allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
|
||||
.unwrap();
|
||||
assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]);
|
||||
assert_eq!(allocation3.prefix_len, 2);
|
||||
|
||||
cache.free(allocation3.blocks.clone(), allocation3.allocation_id);
|
||||
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);
|
||||
|
||||
// 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2.
|
||||
assert_eq!(cache.free_blocks.len(), 11);
|
||||
|
||||
let allocation4 = cache
|
||||
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))
|
||||
.unwrap();
|
||||
assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]);
|
||||
assert_eq!(allocation4.prefix_len, 6);
|
||||
assert_eq!(cache.free_blocks.len(), 11);
|
||||
|
||||
let allocation5 = cache
|
||||
.allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))
|
||||
.unwrap();
|
||||
assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]);
|
||||
assert_eq!(allocation5.prefix_len, 6);
|
||||
assert_eq!(cache.free_blocks.len(), 11);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trie_insertions_have_correct_prefix_len() {
|
||||
let mut trie = RadixTrie::new(1);
|
||||
|
||||
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);
|
||||
|
||||
// Already exists.
|
||||
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3);
|
||||
|
||||
// Completely new at root-level
|
||||
assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0);
|
||||
|
||||
// Contains full prefix, but longer.
|
||||
assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3);
|
||||
|
||||
// Shares partial prefix, we need a split.
|
||||
assert_eq!(
|
||||
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
|
||||
.unwrap(),
|
||||
4
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trie_insertions_block_size() {
|
||||
let mut trie = RadixTrie::new(2);
|
||||
|
||||
assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 0);
|
||||
|
||||
// Already exists.
|
||||
// But needs to be block_size aligned
|
||||
assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 4);
|
||||
|
||||
// Completely new at root-level
|
||||
assert_eq!(trie.insert(&[1, 2, 3, 4], &[1, 2]).unwrap(), 0);
|
||||
|
||||
// Contains full prefix, but longer.
|
||||
assert_eq!(trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4);
|
||||
|
||||
// Shares partial prefix, we need a split.
|
||||
assert_eq!(
|
||||
trie.insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3])
|
||||
.unwrap(),
|
||||
2
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trie_get_returns_correct_blocks() {
|
||||
let mut trie = RadixTrie::new(1);
|
||||
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
|
||||
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
|
||||
.unwrap();
|
||||
|
||||
let mut blocks = Vec::new();
|
||||
trie.find(&[0], &mut blocks);
|
||||
assert_eq!(blocks, vec![0]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[1, 2, 3], &mut blocks);
|
||||
assert_eq!(blocks, vec![1, 2, 3]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3, 4]);
|
||||
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3, 5], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3, 5]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trie_evict_removes_correct_blocks() {
|
||||
let mut trie = RadixTrie::new(1);
|
||||
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
|
||||
.unwrap();
|
||||
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
|
||||
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
|
||||
|
||||
let mut blocks = Vec::new();
|
||||
|
||||
// Remove fewer blocks than the leaf has.
|
||||
assert_eq!(trie.evict(1), vec![7]);
|
||||
trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]);
|
||||
|
||||
// Refresh other leaf.
|
||||
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
|
||||
trie.find(&[1, 2, 3], &mut blocks);
|
||||
|
||||
// Remove exactly the leaf's blocks.
|
||||
assert_eq!(trie.evict(2), vec![5, 6]);
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1, 2, 3]);
|
||||
|
||||
trie.find(&[1, 2, 3], &mut blocks);
|
||||
|
||||
// Remove more blocks than the leaf has.
|
||||
assert_eq!(trie.evict(3), vec![4, 3, 2]);
|
||||
blocks.clear();
|
||||
trie.find(&[0, 1, 2, 3, 4], &mut blocks);
|
||||
assert_eq!(blocks, vec![0, 1]);
|
||||
|
||||
// Clear out the whole trie.
|
||||
assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
|
||||
}
|
||||
}
|
|
@ -148,6 +148,7 @@ async fn prefill(
|
|||
}),
|
||||
inputs: sequence.clone(),
|
||||
truncate: sequence_length,
|
||||
add_special_tokens: true,
|
||||
parameters: Some(parameters.clone()),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: decode_length,
|
||||
|
@ -157,6 +158,7 @@ async fn prefill(
|
|||
top_n_tokens: top_n_tokens.unwrap_or(0),
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
prefix_len: 0,
|
||||
adapter_id: None,
|
||||
})
|
||||
.collect();
|
||||
|
|
|
@ -757,7 +757,12 @@ class AsyncClient:
|
|||
continue
|
||||
payload = byte_payload.decode("utf-8")
|
||||
if payload.startswith("data:"):
|
||||
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
|
||||
payload_data = (
|
||||
payload.lstrip("data:").rstrip("\n").removeprefix(" ")
|
||||
)
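# The stream ends with a literal "[DONE]" data payload (SSE convention).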
|
||||
if payload_data == "[DONE]":
|
||||
break
|
||||
json_payload = json.loads(payload_data)
|
||||
try:
|
||||
response = ChatCompletionChunk(**json_payload)
|
||||
yield response
|
||||
|
|
|
@ -556,6 +556,37 @@
|
|||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/models": {
|
||||
"get": {
|
||||
"tags": [
|
||||
"Text Generation Inference"
|
||||
],
|
||||
"summary": "Get model info",
|
||||
"operationId": "openai_get_model_info",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Served model info",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ModelInfo"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"404": {
|
||||
"description": "Model not found",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/ErrorResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"components": {
|
||||
|
@ -819,6 +850,13 @@
|
|||
"example": "1.0",
|
||||
"nullable": true
|
||||
},
|
||||
"guideline": {
|
||||
"type": "string",
|
||||
"description": "A guideline to be used in the chat_template",
|
||||
"default": "null",
|
||||
"example": "null",
|
||||
"nullable": true
|
||||
},
|
||||
"logit_bias": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
|
@ -917,7 +955,7 @@
|
|||
"tool_prompt": {
|
||||
"type": "string",
|
||||
"description": "A prompt to be appended before the tools",
|
||||
"example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"",
|
||||
"example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.",
|
||||
"nullable": true
|
||||
},
|
||||
"tools": {
|
||||
|
@ -1740,6 +1778,35 @@
|
|||
}
|
||||
]
|
||||
},
|
||||
"ModelInfo": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"id",
|
||||
"object",
|
||||
"created",
|
||||
"owned_by"
|
||||
],
|
||||
"properties": {
|
||||
"created": {
|
||||
"type": "integer",
|
||||
"format": "int64",
|
||||
"example": 1686935002,
|
||||
"minimum": 0
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"example": "gpt2"
|
||||
},
|
||||
"object": {
|
||||
"type": "string",
|
||||
"example": "model"
|
||||
},
|
||||
"owned_by": {
|
||||
"type": "string",
|
||||
"example": "openai"
|
||||
}
|
||||
}
|
||||
},
|
||||
"OutputMessage": {
|
||||
"oneOf": [
|
||||
{
|
||||
|
@ -1817,7 +1884,8 @@
|
|||
"type": "object",
|
||||
"required": [
|
||||
"finish_reason",
|
||||
"generated_tokens"
|
||||
"generated_tokens",
|
||||
"input_length"
|
||||
],
|
||||
"properties": {
|
||||
"finish_reason": {
|
||||
|
@ -1829,6 +1897,12 @@
|
|||
"example": 1,
|
||||
"minimum": 0
|
||||
},
|
||||
"input_length": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"example": 1,
|
||||
"minimum": 0
|
||||
},
|
||||
"seed": {
|
||||
"type": "integer",
|
||||
"format": "int64",
|
||||
|
|
|
@ -17,8 +17,6 @@
|
|||
title: Installation from source
|
||||
- local: supported_models
|
||||
title: Supported Models and Hardware
|
||||
- local: messages_api
|
||||
title: Messages API
|
||||
- local: architecture
|
||||
title: Internal Architecture
|
||||
- local: usage_statistics
|
||||
|
@ -33,8 +31,6 @@
|
|||
title: Serving Private & Gated Models
|
||||
- local: basic_tutorials/using_cli
|
||||
title: Using TGI CLI
|
||||
- local: basic_tutorials/launcher
|
||||
title: All TGI CLI options
|
||||
- local: basic_tutorials/non_core_models
|
||||
title: Non-core Model Serving
|
||||
- local: basic_tutorials/safety
|
||||
|
@ -48,6 +44,14 @@
|
|||
- local: basic_tutorials/train_medusa
|
||||
title: Train Medusa
|
||||
title: Tutorials
|
||||
- sections:
|
||||
- local: reference/launcher
|
||||
title: All TGI CLI options
|
||||
- local: reference/metrics
|
||||
title: Exported Metrics
|
||||
- local: reference/api_reference
|
||||
title: API Reference
|
||||
title: Reference
|
||||
- sections:
|
||||
- local: conceptual/streaming
|
||||
title: Streaming
|
||||
|
@ -64,9 +68,11 @@
|
|||
- local: conceptual/speculation
|
||||
title: Speculation (Medusa, ngram)
|
||||
- local: conceptual/guidance
|
||||
title: How Guidance Works (via outlines
|
||||
title: How Guidance Works (via outlines)
|
||||
- local: conceptual/lora
|
||||
title: LoRA (Low-Rank Adaptation)
|
||||
- local: conceptual/external
|
||||
title: External Resources
|
||||
|
||||
|
||||
title: Conceptual Guides
|
||||
|
|
|
@ -1,81 +1,125 @@
|
|||
# Consuming Text Generation Inference
|
||||
|
||||
There are many ways you can consume Text Generation Inference server in your applications. After launching, you can use the `/generate` route and make a `POST` request to get results from the server. You can also use the `/generate_stream` route if you want TGI to return a stream of tokens. You can make the requests using the tool of your preference, such as curl, Python or TypeScrpt. For a final end-to-end experience, we also open-sourced ChatUI, a chat interface for open-source models.
|
||||
There are many ways to consume the Text Generation Inference (TGI) server in your applications. After launching the server, you can use the [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) `/v1/chat/completions` route and make a `POST` request to get results from the server. You can also pass `"stream": true` to the call if you want TGI to return a stream of tokens.
|
||||
|
||||
For more information on the API, consult the OpenAPI documentation of `text-generation-inference` available [here](https://huggingface.github.io/text-generation-inference).
|
||||
|
||||
You can make the requests using any tool of your preference, such as curl, Python, or TypeScript. For an end-to-end experience, we've open-sourced [ChatUI](https://github.com/huggingface/chat-ui), a chat interface for open-access models.
|
||||
|
||||
## curl
|
||||
|
||||
After the launch, you can query the model using either the `/generate` or `/generate_stream` routes:
|
||||
After a successful server launch, you can query the model using the `/v1/chat/completions` route to get responses that are compliant with the OpenAI Chat Completion spec:
|
||||
|
||||
```bash
|
||||
curl localhost:8080/v1/chat/completions \
|
||||
-X POST \
|
||||
-d '{
|
||||
"model": "tgi",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is deep learning?"
|
||||
}
|
||||
],
|
||||
"stream": true,
|
||||
"max_tokens": 20
|
||||
}' \
|
||||
-H 'Content-Type: application/json'
|
||||
```
|
||||
|
||||
For non-chat use-cases, you can also use the `/generate` and `/generate_stream` routes.
|
||||
|
||||
```bash
|
||||
curl 127.0.0.1:8080/generate \
|
||||
-X POST \
|
||||
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
|
||||
-d '{
|
||||
"inputs":"What is Deep Learning?",
|
||||
"parameters":{
|
||||
"max_new_tokens":20
|
||||
}
|
||||
}' \
|
||||
-H 'Content-Type: application/json'
|
||||
```
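For token streaming outside of the Messages API, `/generate_stream` accepts the same payload and returns the tokens as server-sent events. A minimal sketch, reusing the request body from above:

```bash
curl 127.0.0.1:8080/generate_stream \
    -X POST \
    -d '{
    "inputs":"What is Deep Learning?",
    "parameters":{
        "max_new_tokens":20
    }
}' \
    -H 'Content-Type: application/json'
```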
|
||||
|
||||
## Python
|
||||
|
||||
## Inference Client
|
||||
### Inference Client
|
||||
|
||||
[`huggingface-hub`](https://huggingface.co/docs/huggingface_hub/main/en/index) is a Python library to interact with the Hugging Face Hub, including its endpoints. It provides a nice high-level class, [`~huggingface_hub.InferenceClient`], which makes it easy to make calls to a TGI endpoint. `InferenceClient` also takes care of parameter validation and provides a simple to-use interface.
|
||||
You can simply install `huggingface-hub` package with pip.
|
||||
[`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/main/en/index) is a Python library to interact with the Hugging Face Hub, including its endpoints. It provides a high-level class, [`huggingface_hub.InferenceClient`](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient), which makes it easy to make calls to TGI's Messages API. `InferenceClient` also takes care of parameter validation and provides a simple-to-use interface.
|
||||
|
||||
Install the `huggingface_hub` package via pip.
|
||||
|
||||
```bash
|
||||
pip install huggingface-hub
|
||||
pip install huggingface_hub
|
||||
```
|
||||
|
||||
Once you start the TGI server, instantiate `InferenceClient()` with the URL to the endpoint serving the model. You can then call `text_generation()` to hit the endpoint through Python.
|
||||
You can now use `InferenceClient` in the exact same way you would use the `OpenAI` client in Python:
|
||||
|
||||
```python
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
client = InferenceClient(model="http://127.0.0.1:8080")
|
||||
client.text_generation(prompt="Write a code for snake game")
|
||||
client = InferenceClient(
|
||||
base_url="http://localhost:8080/v1/",
|
||||
)
|
||||
|
||||
output = client.chat.completions.create(
|
||||
model="tgi",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Count to 10"},
|
||||
],
|
||||
stream=True,
|
||||
max_tokens=1024,
|
||||
)
|
||||
|
||||
for chunk in output:
|
||||
print(chunk.choices[0].delta.content)
|
||||
```
|
||||
|
||||
You can do streaming with `InferenceClient` by passing `stream=True`. Streaming will return tokens as they are being generated in the server. To use streaming, you can do as follows:
|
||||
You can check out more details about OpenAI compatibility [here](https://huggingface.co/docs/huggingface_hub/en/guides/inference#openai-compatibility).
|
||||
|
||||
There is also an async version of the client, `AsyncInferenceClient`, based on `asyncio` and `aiohttp`. You can find docs for it [here](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.AsyncInferenceClient).
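As a minimal sketch (assuming the same local endpoint and the OpenAI-compatible interface shown above), the async client mirrors the synchronous calls:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient

client = AsyncInferenceClient(base_url="http://localhost:8080/v1/")

async def main():
    output = await client.chat.completions.create(
        model="tgi",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Count to 10"},
        ],
        stream=True,
        max_tokens=1024,
    )
    # Print the tokens as they arrive from the server.
    async for chunk in output:
        print(chunk.choices[0].delta.content)

asyncio.run(main())
```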
|
||||
|
||||
### OpenAI Client
|
||||
|
||||
You can directly use the OpenAI [Python](https://github.com/openai/openai-python) or [JS](https://github.com/openai/openai-node) clients to interact with TGI.
|
||||
|
||||
Install the OpenAI Python package via pip.
|
||||
|
||||
```bash
|
||||
pip install openai
|
||||
```
|
||||
|
||||
```python
|
||||
for token in client.text_generation("How do you make cheese?", max_new_tokens=12, stream=True):
|
||||
print(token)
|
||||
from openai import OpenAI
|
||||
|
||||
# init the client but point it to TGI
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1/",
|
||||
api_key="-"
|
||||
)
|
||||
|
||||
chat_completion = client.chat.completions.create(
|
||||
model="tgi",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant." },
|
||||
{"role": "user", "content": "What is deep learning?"}
|
||||
],
|
||||
stream=True
|
||||
)
|
||||
|
||||
# iterate and print stream
|
||||
for message in chat_completion:
|
||||
print(message)
|
||||
```
|
||||
|
||||
Another parameter you can use with TGI backend is `details`. You can get more details on generation (tokens, probabilities, etc.) by setting `details` to `True`. When it's specified, TGI will return a `TextGenerationResponse` or `TextGenerationStreamResponse` rather than a string or stream.
|
||||
## UI
|
||||
|
||||
```python
|
||||
output = client.text_generation(prompt="Meaning of life is", details=True)
|
||||
print(output)
|
||||
|
||||
# TextGenerationResponse(generated_text=' a complex concept that is not always clear to the individual. It is a concept that is not always', details=Details(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=20, seed=None, prefill=[], tokens=[Token(id=267, text=' a', logprob=-2.0723474, special=False), Token(id=11235, text=' complex', logprob=-3.1272552, special=False), Token(id=17908, text=' concept', logprob=-1.3632495, special=False),..))
|
||||
```
|
||||
|
||||
You can see how to stream below.
|
||||
|
||||
```python
|
||||
output = client.text_generation(prompt="Meaning of life is", stream=True, details=True)
|
||||
print(next(iter(output)))
|
||||
|
||||
# TextGenerationStreamResponse(token=Token(id=267, text=' a', logprob=-2.0723474, special=False), generated_text=None, details=None)
|
||||
```
|
||||
|
||||
You can check out the details of the function [here](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation). There is also an async version of the client, `AsyncInferenceClient`, based on `asyncio` and `aiohttp`. You can find docs for it [here](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.AsyncInferenceClient)
|
||||
|
||||
|
||||
## ChatUI
|
||||
|
||||
ChatUI is an open-source interface built for LLM serving. It offers many customization options, such as web search with SERP API and more. ChatUI can automatically consume the TGI server and even provides an option to switch between different TGI endpoints. You can try it out at [Hugging Chat](https://huggingface.co/chat/), or use the [ChatUI Docker Space](https://huggingface.co/new-space?template=huggingchat/chat-ui-template) to deploy your own Hugging Chat to Spaces.
|
||||
|
||||
To serve both ChatUI and TGI in same environment, simply add your own endpoints to the `MODELS` variable in `.env.local` file inside the `chat-ui` repository. Provide the endpoints pointing to where TGI is served.
|
||||
|
||||
```
|
||||
{
|
||||
// rest of the model config here
|
||||
"endpoints": [{"url": "https://HOST:PORT/generate_stream"}]
|
||||
}
|
||||
```
|
||||
|
||||
![ChatUI](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/chatui_screen.png)
|
||||
|
||||
## Gradio
|
||||
### Gradio
|
||||
|
||||
Gradio is a Python library that helps you build web applications for your machine learning models with a few lines of code. It has a `ChatInterface` wrapper that helps create neat UIs for chatbots. Let's take a look at how to create a chatbot with streaming mode using TGI and Gradio. First, install Gradio and the Hub Python library.
|
||||
|
||||
|
@ -89,19 +133,28 @@ Assume you are serving your model on port 8080, we will query through [Inference
|
|||
import gradio as gr
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
client = InferenceClient(model="http://127.0.0.1:8080")
|
||||
client = InferenceClient(base_url="http://127.0.0.1:8080")
|
||||
|
||||
def inference(message, history):
|
||||
partial_message = ""
|
||||
for token in client.text_generation(message, max_new_tokens=20, stream=True):
|
||||
partial_message += token
|
||||
output = client.chat.completions.create(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": message},
|
||||
],
|
||||
stream=True,
|
||||
max_tokens=1024,
|
||||
)
|
||||
|
||||
for chunk in output:
|
||||
partial_message += chunk.choices[0].delta.content
|
||||
yield partial_message
|
||||
|
||||
gr.ChatInterface(
|
||||
inference,
|
||||
chatbot=gr.Chatbot(height=300),
|
||||
textbox=gr.Textbox(placeholder="Chat with me!", container=False, scale=7),
|
||||
description="This is the demo for Gradio UI consuming TGI endpoint with LLaMA 7B-Chat model.",
|
||||
description="This is the demo for Gradio UI consuming TGI endpoint.",
|
||||
title="Gradio 🤝 TGI",
|
||||
examples=["Are tomatoes vegetables?"],
|
||||
retry_btn="Retry",
|
||||
|
@ -110,20 +163,7 @@ gr.ChatInterface(
|
|||
).queue().launch()
|
||||
```
|
||||
|
||||
The UI looks like this 👇
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img
|
||||
class="block dark:hidden"
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/gradio-tgi.png"
|
||||
/>
|
||||
<img
|
||||
class="hidden dark:block"
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/gradio-tgi-dark.png"
|
||||
/>
|
||||
</div>
|
||||
|
||||
You can try the demo directly here 👇
|
||||
You can check out the UI and try the demo directly here 👇
|
||||
|
||||
<div class="block dark:hidden">
|
||||
<iframe
|
||||
|
@ -141,15 +181,19 @@ You can try the demo directly here 👇
|
|||
</div>
|
||||
|
||||
|
||||
You can disable streaming mode using `return` instead of `yield` in your inference function, like below.
|
||||
|
||||
```python
|
||||
def inference(message, history):
|
||||
return client.text_generation(message, max_new_tokens=20)
|
||||
```
|
||||
|
||||
You can read more about how to customize a `ChatInterface` [here](https://www.gradio.app/guides/creating-a-chatbot-fast).
|
||||
|
||||
## API documentation
|
||||
### ChatUI
|
||||
|
||||
You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route. The Swagger UI is also available [here](https://huggingface.github.io/text-generation-inference).
|
||||
[ChatUI](https://github.com/huggingface/chat-ui) is an open-source interface built for consuming LLMs. It offers many customization options, such as web search with SERP API and more. ChatUI can automatically consume the TGI server and even provides an option to switch between different TGI endpoints. You can try it out at [Hugging Chat](https://huggingface.co/chat/), or use the [ChatUI Docker Space](https://huggingface.co/new-space?template=huggingchat/chat-ui-template) to deploy your own Hugging Chat to Spaces.
|
||||
|
||||
To serve both ChatUI and TGI in the same environment, simply add your own endpoints to the `MODELS` variable in the `.env.local` file inside the `chat-ui` repository. Provide the endpoints pointing to where TGI is served.
|
||||
|
||||
```
|
||||
{
|
||||
// rest of the model config here
|
||||
"endpoints": [{"url": "https://HOST:PORT/generate_stream"}]
|
||||
}
|
||||
```
|
||||
|
||||
![ChatUI](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/chatui_screen.png)
|
||||
|
|
|
@ -4,7 +4,7 @@ Text Generation Inference improves the model in several aspects.
|
|||
|
||||
## Quantization
|
||||
|
||||
TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323) and [AWQ](https://arxiv.org/abs/2306.00978) quantization. To speed up inference with quantization, simply set `quantize` flag to `bitsandbytes`, `gptq` or `awq` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq) when using AWQ quantization, you need to point to one of the models [here](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to [quantization guide](./../conceptual/quantization)
|
||||
TGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [Marlin](https://github.com/IST-DASLab/marlin), [EETQ](https://github.com/NetEase-FuXi/EETQ), [EXL2](https://github.com/turboderp/exllamav2), and [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) quantization. To speed up inference with quantization, simply set the `quantize` flag to `bitsandbytes`, `gptq`, `awq`, `marlin`, `exl2`, `eetq` or `fp8` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq). Similarly, when using AWQ quantization, you need to point to one of [these models](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to the [quantization guide](./../conceptual/quantization).
|
||||
|
||||
|
||||
## RoPE Scaling
|
||||
|
|
|
@ -4,7 +4,7 @@ Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-
|
|||
|
||||
These features are available starting from version `1.4.3`. They are accessible via the [`huggingface_hub`](https://pypi.org/project/huggingface-hub/) library. The tool support is compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!
|
||||
|
||||
_note: guidance is supported as grammar in the `/generate` endpoint and as tools in the `/chat/completions` endpoint._
|
||||
_note: guidance is supported as grammar in the `/generate` endpoint and as tools in the `/v1/chat/completions` endpoint._
|
||||
|
||||
## How it works
|
||||
|
||||
|
@ -157,7 +157,12 @@ from huggingface_hub import InferenceClient
|
|||
|
||||
client = InferenceClient("http://localhost:3000")
|
||||
|
||||
regexp = "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)"
|
||||
section_regex = "(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"
|
||||
regexp = f"HELLO\.{section_regex}\.WORLD\.{section_regex}"
|
||||
|
||||
# This is a more realistic example of an ip address regex
|
||||
# regexp = f"{section_regex}\.{section_regex}\.{section_regex}\.{section_regex}"
|
||||
|
||||
|
||||
resp = client.text_generation(
|
||||
f"Whats Googles DNS? Please use the following regex: {regexp}",
|
||||
|
@ -170,7 +175,7 @@ resp = client.text_generation(
|
|||
|
||||
|
||||
print(resp)
|
||||
# 7.1.1.1
|
||||
# HELLO.255.WORLD.255
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -84,7 +84,7 @@ print(chat)
|
|||
|
||||
```
|
||||
|
||||
or with OpenAi's library:
|
||||
or with OpenAI's [client library](https://github.com/openai/openai-python):
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
# External Resources
|
||||
|
||||
- Adyen wrote a detailed article about the interplay between TGI's main components: router and server.
|
||||
[LLM inference at scale with TGI (Martin Iglesias Goyanes - Adyen, 2024)](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi)
|
|
@ -36,6 +36,18 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m
|
|||
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
|
||||
```
|
||||
|
||||
Additionally, you can specify the path to the LoRA models using the `LORA_ADAPTERS_PATH` environment variable. For example:
|
||||
|
||||
```bash
|
||||
LORA_ADAPTERS=myadapter=/some/path/to/adapter,myadapter2=/another/path/to/adapter
|
||||
```
|
||||
|
||||
Note that it's possible to mix adapter IDs with `adapter_id=adapter_path` entries, e.g.:
|
||||
|
||||
```bash
|
||||
LORA_ADAPTERS=predibase/dbpedia,myadapter=/path/to/dir/
|
||||
```
|
||||
|
||||
In the server logs, you will see the following message:
|
||||
|
||||
```txt
|
||||
|
|
|
@ -1,6 +1,40 @@
|
|||
# Quantization
|
||||
|
||||
TGI offers GPTQ and bits-and-bytes quantization to quantize large language models.
|
||||
TGI offers many quantization schemes to run LLMs efficiently and fast, depending on your use case. TGI supports GPTQ, AWQ, bits-and-bytes, EETQ, Marlin, EXL2 and fp8 quantization.
|
||||
|
||||
To leverage GPTQ, AWQ, Marlin and EXL2 quants, you must provide pre-quantized weights, whereas for bits-and-bytes, EETQ and fp8, weights are quantized by TGI on the fly.
|
||||
|
||||
We recommend using the official quantization scripts for creating your quants:
|
||||
1. [AWQ](https://github.com/casper-hansen/AutoAWQ/blob/main/examples/quantize.py)
|
||||
2. [GPTQ/ Marlin](https://github.com/AutoGPTQ/AutoGPTQ/blob/main/examples/quantization/basic_usage.py)
|
||||
3. [EXL2](https://github.com/turboderp/exllamav2/blob/master/doc/convert.md)
|
||||
|
||||
For on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest.
|
||||
|
||||
## Quantization with bitsandbytes, EETQ & fp8
|
||||
|
||||
bitsandbytes is a library used to apply 8-bit and 4-bit quantization to models. Unlike GPTQ quantization, bitsandbytes doesn't require a calibration dataset or any post-processing – weights are automatically quantized on load. However, inference with bitsandbytes is slower than GPTQ or FP16 precision.
|
||||
|
||||
8-bit quantization enables multi-billion parameter scale models to fit in smaller hardware without degrading performance too much.
|
||||
In TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇
|
||||
|
||||
```bash
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes
|
||||
```
|
||||
|
||||
4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.
|
||||
|
||||
In TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇
|
||||
|
||||
```bash
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize bitsandbytes-nf4
|
||||
```
|
||||
|
||||
You can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
|
||||
|
||||
Similarly, you can pass `--quantize eetq` or `--quantize fp8` for the respective quantization schemes.
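For example, the EETQ variant of the Docker command shown above only changes the flag (assuming the same `$model` and `$volume` variables):

```bash
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --quantize eetq
```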
|
||||
|
||||
In addition to this, TGI allows creating GPTQ quants directly by passing the model weights and a calibration dataset.
|
||||
|
||||
## Quantization with GPTQ
|
||||
|
||||
|
@ -36,24 +70,3 @@ You can learn more about the quantization options by running `text-generation-se
|
|||
|
||||
If you wish to do more with GPTQ models (e.g. train an adapter on top), you can read about transformers GPTQ integration [here](https://huggingface.co/blog/gptq-integration).
|
||||
You can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf).
@ -1,5 +1,6 @@
# Streaming

## What is Streaming?

Token streaming is the mode in which the server returns the tokens one by one as the model generates them. This enables showing progressive generations to the user rather than waiting for the whole generation. Streaming is an essential aspect of the end-user experience as it reduces latency, one of the most critical aspects of a smooth experience.
@ -48,34 +49,29 @@ To stream tokens with `InferenceClient`, simply pass `stream=True` and iterate o
```python
from huggingface_hub import InferenceClient

client = InferenceClient(base_url="http://127.0.0.1:8080")
output = client.chat.completions.create(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Count to 10"},
    ],
    stream=True,
    max_tokens=1024,
)

for chunk in output:
    print(chunk.choices[0].delta.content)

# 1
# 2
# 3
# 4
# 5
# 6
# 7
# 8
# 9
# 10
```

If you want additional details, you can add `details=True`. In this case, you get a `TextGenerationStreamResponse` which contains additional information such as the probabilities and the tokens. For the final response in the stream, it also returns the full generated text.

```python
for details in client.text_generation("How do you make cheese?", max_new_tokens=12, details=True, stream=True):
    print(details)

# TextGenerationStreamResponse(token=Token(id=193, text='\n', logprob=-0.007358551, special=False), generated_text=None, details=None)
# TextGenerationStreamResponse(token=Token(id=2044, text='To', logprob=-1.1357422, special=False), generated_text=None, details=None)
# TextGenerationStreamResponse(token=Token(id=717, text=' make', logprob=-0.009841919, special=False), generated_text=None, details=None)
# ...
# TextGenerationStreamResponse(token=Token(id=25, text='.', logprob=-1.3408203, special=False), generated_text='\nTo make cheese, you need to start with milk.', details=StreamDetails(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=12, seed=None))
```
The `huggingface_hub` library also comes with an `AsyncInferenceClient` in case you need to handle the requests concurrently.
@ -83,31 +79,46 @@ The `huggingface_hub` library also comes with an `AsyncInferenceClient` in case
```python
from huggingface_hub import AsyncInferenceClient
import asyncio

client = AsyncInferenceClient(base_url="http://127.0.0.1:8080")

async def main():
    stream = await client.chat.completions.create(
        messages=[{"role": "user", "content": "Say this is a test"}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(main())

# This
# is
# a
# test
#.
```

### Streaming with cURL

To use the OpenAI Chat Completions compatible Messages API `v1/chat/completions` endpoint with curl, you can add the `-N` flag, which disables curl default buffering and shows data as it arrives from the server.

```curl
curl -N localhost:8080/v1/chat/completions \
    -X POST \
    -d '{
  "model": "tgi",
  "messages": [
    {
      "role": "system",
      "content": "You are a helpful assistant."
    },
    {
      "role": "user",
      "content": "What is deep learning?"
    }
  ],
  "stream": true,
  "max_tokens": 20
}' \
    -H 'Content-Type: application/json'
```
@ -12,7 +12,24 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \
    --device=/dev/dri \
    --ipc=host --shm-size 1g --net host -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:2.2.0-intel-xpu \
    --model-id $model --cuda-graphs 0
```

# Using TGI with Intel CPUs

Intel® Extension for PyTorch (IPEX) also provides further optimizations for Intel CPUs. IPEX provides optimization operations such as flash attention, page attention, Add + LayerNorm, ROPE and more.

On a server powered by an Intel CPU, TGI can be launched with the following command:

```bash
model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

docker run --rm --privileged --cap-add=sys_nice \
    --device=/dev/dri \
    --ipc=host --shm-size 1g --net host -v $volume:/data \
    ghcr.io/huggingface/text-generation-inference:2.2.0-intel-cpu \
    --model-id $model --cuda-graphs 0
```
@ -21,7 +21,7 @@ TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPU

## Consuming TGI

Once TGI is running, you can use the `generate` endpoint or the OpenAI Chat Completions API compatible [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. Below you can see a simple snippet to query the endpoint.

<inferencesnippet>
<python>
@ -1,18 +1,31 @@
# HTTP API Reference

#### Table of Contents

- [Text Generation Inference custom API](#text-generation-inference-custom-api)
- [OpenAI Messages API](#openai-messages-api)
- [Making a Request](#making-a-request)
- [Streaming](#streaming)
- [Synchronous](#synchronous)
- [Hugging Face Inference Endpoints](#hugging-face-inference-endpoints)
- [Cloud Providers](#cloud-providers)
- [Amazon SageMaker](#amazon-sagemaker)

The HTTP API is a RESTful API that allows you to interact with the text-generation-inference component. Two endpoints are available:
* Text Generation Inference [custom API](https://huggingface.github.io/text-generation-inference/)
* OpenAI's [Messages API](#openai-messages-api)

## Text Generation Inference custom API

Check the [API documentation](https://huggingface.github.io/text-generation-inference/) for more information on how to interact with the Text Generation Inference API.
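For instance, a minimal sketch of calling the custom `generate` endpoint with curl, reusing the example payload from the streaming section and assuming the server listens locally on port 8080:

```bash
# Query TGI's custom /generate endpoint (non-streaming) with a small prompt
curl 127.0.0.1:8080/generate \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -H 'Content-Type: application/json'
```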

## OpenAI Messages API

Text Generation Inference (TGI) now supports the Messages API, which is fully compatible with the OpenAI Chat Completion API. This feature is available starting from version 1.4.0. You can use OpenAI's client libraries or third-party libraries expecting OpenAI schema to interact with TGI's Messages API. Below are some examples of how to utilize this compatibility.

> **Note:** The Messages API is supported from TGI version 1.4.0 and above. Ensure you are using a compatible version to access this feature.

## Making a Request

You can make a request to TGI's Messages API using `curl`. Here's an example:
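A minimal sketch, mirroring the streaming request shown earlier in this document but with `stream` set to `false` (assuming the server listens on `localhost:8080`):

```bash
# Non-streaming chat completion request against TGI's OpenAI-compatible Messages API
curl localhost:8080/v1/chat/completions \
    -X POST \
    -d '{
  "model": "tgi",
  "messages": [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is deep learning?"}
  ],
  "stream": false,
  "max_tokens": 20
}' \
    -H 'Content-Type: application/json'
```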
@ -0,0 +1,30 @@
# Metrics

TGI exposes multiple metrics that can be collected via the `/metrics` Prometheus endpoint.
These metrics can be used to monitor the performance of TGI, autoscale deployments, and to help identify bottlenecks; a minimal scraping sketch is shown below.
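For example, assuming a TGI instance running locally on port 8080, the endpoint can be checked directly (a sketch; any Prometheus-compatible scraper can consume the same output):

```bash
# Fetch Prometheus-format metrics from a locally running TGI instance
curl -s http://localhost:8080/metrics | grep "^tgi_request_count"
```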

The following metrics are exposed:

| Metric Name | Description | Type | Unit |
|--------------------------------------------|------------------------------------------------------------------------------------------|-----------|---------|
| `tgi_batch_current_max_tokens` | Maximum tokens for the current batch | Gauge | Count |
| `tgi_batch_current_size` | Current batch size | Gauge | Count |
| `tgi_batch_decode_duration` | Time spent decoding a batch per method (prefill or decode) | Histogram | Seconds |
| `tgi_batch_filter_duration` | Time spent filtering batches and sending generated tokens per method (prefill or decode) | Histogram | Seconds |
| `tgi_batch_forward_duration` | Batch forward duration per method (prefill or decode) | Histogram | Seconds |
| `tgi_batch_inference_count` | Inference calls per method (prefill or decode) | Counter | Count |
| `tgi_batch_inference_duration` | Batch inference duration | Histogram | Seconds |
| `tgi_batch_inference_success` | Number of successful inference calls per method (prefill or decode) | Counter | Count |
| `tgi_batch_next_size` | Batch size of the next batch | Histogram | Count |
| `tgi_queue_size` | Current queue size | Gauge | Count |
| `tgi_request_count` | Total number of requests | Counter | Count |
| `tgi_request_duration` | Total time spent processing the request (e2e latency) | Histogram | Seconds |
| `tgi_request_generated_tokens` | Generated tokens per request | Histogram | Count |
| `tgi_request_inference_duration` | Request inference duration | Histogram | Seconds |
| `tgi_request_input_length` | Input token length per request | Histogram | Count |
| `tgi_request_max_new_tokens` | Maximum new tokens per request | Histogram | Count |
| `tgi_request_mean_time_per_token_duration` | Mean time per token per request (inter-token latency) | Histogram | Seconds |
| `tgi_request_queue_duration` | Time spent in the queue per request | Histogram | Seconds |
| `tgi_request_skipped_tokens` | Speculated tokens per request | Histogram | Count |
| `tgi_request_success` | Number of successful requests | Counter | |
| `tgi_request_validation_duration` | Time spent validating the request | Histogram | Seconds |
@ -1,22 +1,22 @@

# Supported Models and Hardware

Text Generation Inference enables serving optimized models on specific hardware for the highest performance. The following sections list which models (VLMs & LLMs) are supported.

## Supported Models

- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Gemma](https://huggingface.co/google/gemma-7b)
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)
- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
- [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
- [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj)
- [Mistral](https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407)
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1)
- [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder)
- [Phi](https://huggingface.co/microsoft/phi-1_5)
@ -32,6 +32,7 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Mpt](https://huggingface.co/mosaicml/mpt-7b-instruct)
- [Gpt2](https://huggingface.co/openai-community/gpt2)
- [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
- [Gptj](https://huggingface.co/EleutherAI/gpt-j-6b)
- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) (Multimodal)
@ -0,0 +1,963 @@
|
|||
{
|
||||
"nodes": {
|
||||
"cachix": {
|
||||
"inputs": {
|
||||
"devenv": [
|
||||
"crate2nix"
|
||||
],
|
||||
"flake-compat": [
|
||||
"crate2nix"
|
||||
],
|
||||
"nixpkgs": "nixpkgs",
|
||||
"pre-commit-hooks": [
|
||||
"crate2nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1709700175,
|
||||
"narHash": "sha256-A0/6ZjLmT9qdYzKHmevnEIC7G+GiZ4UCr8v0poRPzds=",
|
||||
"owner": "cachix",
|
||||
"repo": "cachix",
|
||||
"rev": "be97b37989f11b724197b5f4c7ffd78f12c8c4bf",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "cachix",
|
||||
"ref": "latest",
|
||||
"repo": "cachix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"cachix_2": {
|
||||
"inputs": {
|
||||
"devenv": [
|
||||
"crate2nix",
|
||||
"crate2nix_stable"
|
||||
],
|
||||
"flake-compat": [
|
||||
"crate2nix",
|
||||
"crate2nix_stable"
|
||||
],
|
||||
"nixpkgs": "nixpkgs_2",
|
||||
"pre-commit-hooks": [
|
||||
"crate2nix",
|
||||
"crate2nix_stable"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1716549461,
|
||||
"narHash": "sha256-lHy5kgx6J8uD+16SO47dPrbob98sh+W1tf4ceSqPVK4=",
|
||||
"owner": "cachix",
|
||||
"repo": "cachix",
|
||||
"rev": "e2bb269fb8c0828d5d4d2d7b8d09ea85abcacbd4",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "cachix",
|
||||
"ref": "latest",
|
||||
"repo": "cachix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"cachix_3": {
|
||||
"inputs": {
|
||||
"devenv": [
|
||||
"crate2nix",
|
||||
"crate2nix_stable",
|
||||
"crate2nix_stable"
|
||||
],
|
||||
"flake-compat": [
|
||||
"crate2nix",
|
||||
"crate2nix_stable",
|
||||
"crate2nix_stable"
|
||||
],
|
||||
"nixpkgs": "nixpkgs_3",
|
||||
"pre-commit-hooks": [
|
||||
"crate2nix",
|
||||
"crate2nix_stable",
|
||||
"crate2nix_stable"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1716549461,
|
||||
"narHash": "sha256-lHy5kgx6J8uD+16SO47dPrbob98sh+W1tf4ceSqPVK4=",
|
||||
"owner": "cachix",
|
||||
"repo": "cachix",
|
||||
"rev": "e2bb269fb8c0828d5d4d2d7b8d09ea85abcacbd4",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "cachix",
|
||||
"ref": "latest",
|
||||
"repo": "cachix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"crate2nix": {
|
||||
"inputs": {
|
||||
"cachix": "cachix",
|
||||
"crate2nix_stable": "crate2nix_stable",
|
||||
"devshell": "devshell_3",
|
||||
"flake-compat": "flake-compat_3",
|
||||
"flake-parts": "flake-parts_3",
|
||||
"nix-test-runner": "nix-test-runner_3",
|
||||
"nixpkgs": [
|
||||
"tgi-nix",
|
||||
"nixpkgs"
|
||||
],
|
||||
"pre-commit-hooks": "pre-commit-hooks_3"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1723311214,
|
||||
"narHash": "sha256-xdGZQBEa1AC2us/sY3igS/CucWY6jErXsAvCFRhB2LI=",
|
||||
"owner": "nix-community",
|
||||
"repo": "crate2nix",
|
||||
"rev": "236f6addfd452a48be805819e3216af79e988fd5",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-community",
|
||||
"repo": "crate2nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"crate2nix_stable": {
|
||||
"inputs": {
|
||||
"cachix": "cachix_2",
|
||||
"crate2nix_stable": "crate2nix_stable_2",
|
||||
"devshell": "devshell_2",
|
||||
"flake-compat": "flake-compat_2",
|
||||
"flake-parts": "flake-parts_2",
|
||||
"nix-test-runner": "nix-test-runner_2",
|
||||
"nixpkgs": "nixpkgs_5",
|
||||
"pre-commit-hooks": "pre-commit-hooks_2"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1719760004,
|
||||
"narHash": "sha256-esWhRnt7FhiYq0CcIxw9pvH+ybOQmWBfHYMtleaMhBE=",
|
||||
"owner": "nix-community",
|
||||
"repo": "crate2nix",
|
||||
"rev": "1dee214bb20855fa3e1e7bb98d28922ddaff8c57",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-community",
|
||||
"ref": "0.14.1",
|
||||
"repo": "crate2nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"crate2nix_stable_2": {
|
||||
"inputs": {
|
||||
"cachix": "cachix_3",
|
||||
"crate2nix_stable": "crate2nix_stable_3",
|
||||
"devshell": "devshell",
|
||||
"flake-compat": "flake-compat",
|
||||
"flake-parts": "flake-parts",
|
||||
"nix-test-runner": "nix-test-runner",
|
||||
"nixpkgs": "nixpkgs_4",
|
||||
"pre-commit-hooks": "pre-commit-hooks"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1712821484,
|
||||
"narHash": "sha256-rGT3CW64cJS9nlnWPFWSc1iEa3dNZecVVuPVGzcsHe8=",
|
||||
"owner": "nix-community",
|
||||
"repo": "crate2nix",
|
||||
"rev": "42883afcad3823fa5811e967fb7bff54bc3c9d6d",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-community",
|
||||
"ref": "0.14.0",
|
||||
"repo": "crate2nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"crate2nix_stable_3": {
|
||||
"inputs": {
|
||||
"flake-utils": "flake-utils"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1702842982,
|
||||
"narHash": "sha256-A9AowkHIjsy1a4LuiPiVP88FMxyCWK41flZEZOUuwQM=",
|
||||
"owner": "nix-community",
|
||||
"repo": "crate2nix",
|
||||
"rev": "75ac2973affa6b9b4f661a7b592cba6e4f51d426",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-community",
|
||||
"ref": "0.12.0",
|
||||
"repo": "crate2nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"devshell": {
|
||||
"inputs": {
|
||||
"flake-utils": "flake-utils_2",
|
||||
"nixpkgs": [
|
||||
"crate2nix",
|
||||
"crate2nix_stable",
|
||||
"crate2nix_stable",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1717408969,
|
||||
"narHash": "sha256-Q0OEFqe35fZbbRPPRdrjTUUChKVhhWXz3T9ZSKmaoVY=",
|
||||
"owner": "numtide",
|
||||
"repo": "devshell",
|
||||
"rev": "1ebbe68d57457c8cae98145410b164b5477761f4",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "devshell",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"devshell_2": {
|
||||
"inputs": {
|
||||
"flake-utils": "flake-utils_3",
|
||||
"nixpkgs": [
|
||||
"crate2nix",
|
||||
"crate2nix_stable",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1717408969,
|
||||
"narHash": "sha256-Q0OEFqe35fZbbRPPRdrjTUUChKVhhWXz3T9ZSKmaoVY=",
|
||||
"owner": "numtide",
|
||||
"repo": "devshell",
|
||||
"rev": "1ebbe68d57457c8cae98145410b164b5477761f4",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "devshell",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"devshell_3": {
|
||||
"inputs": {
|
||||
"flake-utils": "flake-utils_4",
|
||||
"nixpkgs": [
|
||||
"crate2nix",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1711099426,
|
||||
"narHash": "sha256-HzpgM/wc3aqpnHJJ2oDqPBkNsqWbW0WfWUO8lKu8nGk=",
|
||||
"owner": "numtide",
|
||||
"repo": "devshell",
|
||||
"rev": "2d45b54ca4a183f2fdcf4b19c895b64fbf620ee8",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "devshell",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-compat": {
|
||||
"locked": {
|
||||
"lastModified": 1696426674,
|
||||
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
|
||||
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
|
||||
"revCount": 57,
|
||||
"type": "tarball",
|
||||
"url": "https://api.flakehub.com/f/pinned/edolstra/flake-compat/1.0.1/018afb31-abd1-7bff-a5e4-cff7e18efb7a/source.tar.gz"
|
||||
},
|
||||
"original": {
|
||||
"type": "tarball",
|
||||
"url": "https://flakehub.com/f/edolstra/flake-compat/1.tar.gz"
|
||||
}
|
||||
},
|
||||
"flake-compat_2": {
|
||||
"locked": {
|
||||
"lastModified": 1696426674,
|
||||
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
|
||||
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
|
||||
"revCount": 57,
|
||||
"type": "tarball",
|
||||
"url": "https://api.flakehub.com/f/pinned/edolstra/flake-compat/1.0.1/018afb31-abd1-7bff-a5e4-cff7e18efb7a/source.tar.gz"
|
||||
},
|
||||
"original": {
|
||||
"type": "tarball",
|
||||
"url": "https://flakehub.com/f/edolstra/flake-compat/1.tar.gz"
|
||||
}
|
||||
},
|
||||
"flake-compat_3": {
|
||||
"locked": {
|
||||
"lastModified": 1696426674,
|
||||
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
|
||||
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
|
||||
"revCount": 57,
|
||||
"type": "tarball",
|
||||
"url": "https://api.flakehub.com/f/pinned/edolstra/flake-compat/1.0.1/018afb31-abd1-7bff-a5e4-cff7e18efb7a/source.tar.gz"
|
||||
},
|
||||
"original": {
|
||||
"type": "tarball",
|
||||
"url": "https://flakehub.com/f/edolstra/flake-compat/1.tar.gz"
|
||||
}
|
||||
},
|
||||
"flake-compat_4": {
|
||||
"locked": {
|
||||
"lastModified": 1696426674,
|
||||
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
|
||||
"owner": "edolstra",
|
||||
"repo": "flake-compat",
|
||||
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "edolstra",
|
||||
"repo": "flake-compat",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-parts": {
|
||||
"inputs": {
|
||||
"nixpkgs-lib": [
|
||||
"crate2nix",
|
||||
"crate2nix_stable",
|
||||
"crate2nix_stable",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1719745305,
|
||||
"narHash": "sha256-xwgjVUpqSviudEkpQnioeez1Uo2wzrsMaJKJClh+Bls=",
|
||||
"owner": "hercules-ci",
|
||||
"repo": "flake-parts",
|
||||
"rev": "c3c5ecc05edc7dafba779c6c1a61cd08ac6583e9",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "hercules-ci",
|
||||
"repo": "flake-parts",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-parts_2": {
|
||||
"inputs": {
|
||||
"nixpkgs-lib": [
|
||||
"crate2nix",
|
||||
"crate2nix_stable",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1719745305,
|
||||
"narHash": "sha256-xwgjVUpqSviudEkpQnioeez1Uo2wzrsMaJKJClh+Bls=",
|
||||
"owner": "hercules-ci",
|
||||
"repo": "flake-parts",
|
||||
"rev": "c3c5ecc05edc7dafba779c6c1a61cd08ac6583e9",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "hercules-ci",
|
||||
"repo": "flake-parts",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-parts_3": {
|
||||
"inputs": {
|
||||
"nixpkgs-lib": [
|
||||
"crate2nix",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1712014858,
|
||||
"narHash": "sha256-sB4SWl2lX95bExY2gMFG5HIzvva5AVMJd4Igm+GpZNw=",
|
||||
"owner": "hercules-ci",
|
||||
"repo": "flake-parts",
|
||||
"rev": "9126214d0a59633752a136528f5f3b9aa8565b7d",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "hercules-ci",
|
||||
"repo": "flake-parts",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-utils": {
|
||||
"inputs": {
|
||||
"systems": "systems"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1694529238,
|
||||
"narHash": "sha256-zsNZZGTGnMOf9YpHKJqMSsa0dXbfmxeoJ7xHlrt+xmY=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "ff7b65b44d01cf9ba6a71320833626af21126384",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-utils_2": {
|
||||
"inputs": {
|
||||
"systems": "systems_2"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1701680307,
|
||||
"narHash": "sha256-kAuep2h5ajznlPMD9rnQyffWG8EM/C73lejGofXvdM8=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "4022d587cbbfd70fe950c1e2083a02621806a725",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-utils_3": {
|
||||
"inputs": {
|
||||
"systems": "systems_3"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1701680307,
|
||||
"narHash": "sha256-kAuep2h5ajznlPMD9rnQyffWG8EM/C73lejGofXvdM8=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "4022d587cbbfd70fe950c1e2083a02621806a725",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-utils_4": {
|
||||
"inputs": {
|
||||
"systems": "systems_4"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1701680307,
|
||||
"narHash": "sha256-kAuep2h5ajznlPMD9rnQyffWG8EM/C73lejGofXvdM8=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "4022d587cbbfd70fe950c1e2083a02621806a725",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-utils_5": {
|
||||
"inputs": {
|
||||
"systems": "systems_5"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1710146030,
|
||||
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-utils_6": {
|
||||
"inputs": {
|
||||
"systems": "systems_6"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1710146030,
|
||||
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"gitignore": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"crate2nix",
|
||||
"crate2nix_stable",
|
||||
"crate2nix_stable",
|
||||
"pre-commit-hooks",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1709087332,
|
||||
"narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=",
|
||||
"owner": "hercules-ci",
|
||||
"repo": "gitignore.nix",
|
||||
"rev": "637db329424fd7e46cf4185293b9cc8c88c95394",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "hercules-ci",
|
||||
"repo": "gitignore.nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"gitignore_2": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"crate2nix",
|
||||
"crate2nix_stable",
|
||||
"pre-commit-hooks",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1709087332,
|
||||
"narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=",
|
||||
"owner": "hercules-ci",
|
||||
"repo": "gitignore.nix",
|
||||
"rev": "637db329424fd7e46cf4185293b9cc8c88c95394",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "hercules-ci",
|
||||
"repo": "gitignore.nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"gitignore_3": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"crate2nix",
|
||||
"pre-commit-hooks",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1709087332,
|
||||
"narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=",
|
||||
"owner": "hercules-ci",
|
||||
"repo": "gitignore.nix",
|
||||
"rev": "637db329424fd7e46cf4185293b9cc8c88c95394",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "hercules-ci",
|
||||
"repo": "gitignore.nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nix-filter": {
|
||||
"locked": {
|
||||
"lastModified": 1710156097,
|
||||
"narHash": "sha256-1Wvk8UP7PXdf8bCCaEoMnOT1qe5/Duqgj+rL8sRQsSM=",
|
||||
"owner": "numtide",
|
||||
"repo": "nix-filter",
|
||||
"rev": "3342559a24e85fc164b295c3444e8a139924675b",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "nix-filter",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nix-test-runner": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1588761593,
|
||||
"narHash": "sha256-FKJykltAN/g3eIceJl4SfDnnyuH2jHImhMrXS2KvGIs=",
|
||||
"owner": "stoeffel",
|
||||
"repo": "nix-test-runner",
|
||||
"rev": "c45d45b11ecef3eb9d834c3b6304c05c49b06ca2",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "stoeffel",
|
||||
"repo": "nix-test-runner",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nix-test-runner_2": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1588761593,
|
||||
"narHash": "sha256-FKJykltAN/g3eIceJl4SfDnnyuH2jHImhMrXS2KvGIs=",
|
||||
"owner": "stoeffel",
|
||||
"repo": "nix-test-runner",
|
||||
"rev": "c45d45b11ecef3eb9d834c3b6304c05c49b06ca2",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "stoeffel",
|
||||
"repo": "nix-test-runner",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nix-test-runner_3": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1588761593,
|
||||
"narHash": "sha256-FKJykltAN/g3eIceJl4SfDnnyuH2jHImhMrXS2KvGIs=",
|
||||
"owner": "stoeffel",
|
||||
"repo": "nix-test-runner",
|
||||
"rev": "c45d45b11ecef3eb9d834c3b6304c05c49b06ca2",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "stoeffel",
|
||||
"repo": "nix-test-runner",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1700612854,
|
||||
"narHash": "sha256-yrQ8osMD+vDLGFX7pcwsY/Qr5PUd6OmDMYJZzZi0+zc=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "19cbff58383a4ae384dea4d1d0c823d72b49d614",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs_2": {
|
||||
"locked": {
|
||||
"lastModified": 1715534503,
|
||||
"narHash": "sha256-5ZSVkFadZbFP1THataCaSf0JH2cAH3S29hU9rrxTEqk=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "2057814051972fa1453ddfb0d98badbea9b83c06",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs_3": {
|
||||
"locked": {
|
||||
"lastModified": 1715534503,
|
||||
"narHash": "sha256-5ZSVkFadZbFP1THataCaSf0JH2cAH3S29hU9rrxTEqk=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "2057814051972fa1453ddfb0d98badbea9b83c06",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs_4": {
|
||||
"locked": {
|
||||
"lastModified": 1719506693,
|
||||
"narHash": "sha256-C8e9S7RzshSdHB7L+v9I51af1gDM5unhJ2xO1ywxNH8=",
|
||||
"path": "/nix/store/4p0avw1s3vf27hspgqsrqs37gxk4i83i-source",
|
||||
"rev": "b2852eb9365c6de48ffb0dc2c9562591f652242a",
|
||||
"type": "path"
|
||||
},
|
||||
"original": {
|
||||
"id": "nixpkgs",
|
||||
"type": "indirect"
|
||||
}
|
||||
},
|
||||
"nixpkgs_5": {
|
||||
"locked": {
|
||||
"lastModified": 1719506693,
|
||||
"narHash": "sha256-C8e9S7RzshSdHB7L+v9I51af1gDM5unhJ2xO1ywxNH8=",
|
||||
"path": "/nix/store/4p0avw1s3vf27hspgqsrqs37gxk4i83i-source",
|
||||
"rev": "b2852eb9365c6de48ffb0dc2c9562591f652242a",
|
||||
"type": "path"
|
||||
},
|
||||
"original": {
|
||||
"id": "nixpkgs",
|
||||
"type": "indirect"
|
||||
}
|
||||
},
|
||||
"nixpkgs_6": {
|
||||
"locked": {
|
||||
"lastModified": 1723912943,
|
||||
"narHash": "sha256-39F9GzyhxYcY3wTeKuEFWRJWcrGBosO4nf4xzMTWZX8=",
|
||||
"owner": "danieldk",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "b82cdca86dbb30013b76c4b55d48806476820a5c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "danieldk",
|
||||
"ref": "cuda-12.4",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pre-commit-hooks": {
|
||||
"inputs": {
|
||||
"flake-compat": [
|
||||
"crate2nix",
|
||||
"crate2nix_stable",
|
||||
"crate2nix_stable",
|
||||
"flake-compat"
|
||||
],
|
||||
"gitignore": "gitignore",
|
||||
"nixpkgs": [
|
||||
"crate2nix",
|
||||
"crate2nix_stable",
|
||||
"crate2nix_stable",
|
||||
"nixpkgs"
|
||||
],
|
||||
"nixpkgs-stable": [
|
||||
"crate2nix",
|
||||
"crate2nix_stable",
|
||||
"crate2nix_stable",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1719259945,
|
||||
"narHash": "sha256-F1h+XIsGKT9TkGO3omxDLEb/9jOOsI6NnzsXFsZhry4=",
|
||||
"owner": "cachix",
|
||||
"repo": "pre-commit-hooks.nix",
|
||||
"rev": "0ff4381bbb8f7a52ca4a851660fc7a437a4c6e07",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "cachix",
|
||||
"repo": "pre-commit-hooks.nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pre-commit-hooks_2": {
|
||||
"inputs": {
|
||||
"flake-compat": [
|
||||
"crate2nix",
|
||||
"crate2nix_stable",
|
||||
"flake-compat"
|
||||
],
|
||||
"gitignore": "gitignore_2",
|
||||
"nixpkgs": [
|
||||
"crate2nix",
|
||||
"crate2nix_stable",
|
||||
"nixpkgs"
|
||||
],
|
||||
"nixpkgs-stable": [
|
||||
"crate2nix",
|
||||
"crate2nix_stable",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1719259945,
|
||||
"narHash": "sha256-F1h+XIsGKT9TkGO3omxDLEb/9jOOsI6NnzsXFsZhry4=",
|
||||
"owner": "cachix",
|
||||
"repo": "pre-commit-hooks.nix",
|
||||
"rev": "0ff4381bbb8f7a52ca4a851660fc7a437a4c6e07",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "cachix",
|
||||
"repo": "pre-commit-hooks.nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pre-commit-hooks_3": {
|
||||
"inputs": {
|
||||
"flake-compat": [
|
||||
"crate2nix",
|
||||
"flake-compat"
|
||||
],
|
||||
"flake-utils": "flake-utils_5",
|
||||
"gitignore": "gitignore_3",
|
||||
"nixpkgs": [
|
||||
"crate2nix",
|
||||
"nixpkgs"
|
||||
],
|
||||
"nixpkgs-stable": [
|
||||
"crate2nix",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1712055707,
|
||||
"narHash": "sha256-4XLvuSIDZJGS17xEwSrNuJLL7UjDYKGJSbK1WWX2AK8=",
|
||||
"owner": "cachix",
|
||||
"repo": "pre-commit-hooks.nix",
|
||||
"rev": "e35aed5fda3cc79f88ed7f1795021e559582093a",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "cachix",
|
||||
"repo": "pre-commit-hooks.nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"crate2nix": "crate2nix",
|
||||
"flake-utils": "flake-utils_6",
|
||||
"nix-filter": "nix-filter",
|
||||
"nixpkgs": [
|
||||
"tgi-nix",
|
||||
"nixpkgs"
|
||||
],
|
||||
"rust-overlay": "rust-overlay",
|
||||
"tgi-nix": "tgi-nix"
|
||||
}
|
||||
},
|
||||
"rust-overlay": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"tgi-nix",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1724638882,
|
||||
"narHash": "sha256-ap2jIQi/FuUHR6HCht6ASWhoz8EiB99XmI8Esot38VE=",
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"rev": "19b70f147b9c67a759e35824b241f1ed92e46694",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"systems": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"systems_2": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"systems_3": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"systems_4": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"systems_5": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"systems_6": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"tgi-nix": {
|
||||
"inputs": {
|
||||
"flake-compat": "flake-compat_4",
|
||||
"nixpkgs": "nixpkgs_6"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1725011596,
|
||||
"narHash": "sha256-zfq8lOXFgJnKxxsqSelHuKUvhxgH3cEmLoAgsOO62Cg=",
|
||||
"owner": "danieldk",
|
||||
"repo": "tgi-nix",
|
||||
"rev": "717c2b07e38538abf05237cca65b2d1363c2c9af",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "danieldk",
|
||||
"repo": "tgi-nix",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
"version": 7
|
||||
}
|
|
@ -0,0 +1,120 @@
|
|||
{
|
||||
inputs = {
|
||||
crate2nix = {
|
||||
url = "github:nix-community/crate2nix";
|
||||
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||
};
|
||||
nix-filter.url = "github:numtide/nix-filter";
|
||||
tgi-nix.url = "github:danieldk/tgi-nix";
|
||||
nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
rust-overlay = {
|
||||
url = "github:oxalica/rust-overlay";
|
||||
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||
};
|
||||
};
|
||||
outputs =
|
||||
{
|
||||
self,
|
||||
crate2nix,
|
||||
nix-filter,
|
||||
nixpkgs,
|
||||
flake-utils,
|
||||
rust-overlay,
|
||||
tgi-nix,
|
||||
}:
|
||||
flake-utils.lib.eachDefaultSystem (
|
||||
system:
|
||||
let
|
||||
cargoNix = crate2nix.tools.${system}.appliedCargoNix {
|
||||
name = "tgi";
|
||||
src = ./.;
|
||||
additionalCargoNixArgs = [ "--all-features" ];
|
||||
};
|
||||
pkgs = import nixpkgs {
|
||||
inherit system;
|
||||
inherit (tgi-nix.lib) config;
|
||||
overlays = [
|
||||
rust-overlay.overlays.default
|
||||
tgi-nix.overlays.default
|
||||
];
|
||||
};
|
||||
crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; };
|
||||
benchmark = cargoNix.workspaceMembers.text-generation-benchmark.build.override {
|
||||
inherit crateOverrides;
|
||||
};
|
||||
launcher = cargoNix.workspaceMembers.text-generation-launcher.build.override {
|
||||
inherit crateOverrides;
|
||||
};
|
||||
router = cargoNix.workspaceMembers.text-generation-router-v3.build.override {
|
||||
inherit crateOverrides;
|
||||
};
|
||||
server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
|
||||
in
|
||||
{
|
||||
devShells = with pkgs; rec {
|
||||
default = pure;
|
||||
|
||||
pure = mkShell {
|
||||
buildInputs = [
|
||||
benchmark
|
||||
launcher
|
||||
router
|
||||
server
|
||||
];
|
||||
};
|
||||
|
||||
impure = mkShell {
|
||||
buildInputs =
|
||||
[
|
||||
openssl.dev
|
||||
pkg-config
|
||||
(rust-bin.stable.latest.default.override {
|
||||
extensions = [
|
||||
"rust-analyzer"
|
||||
"rust-src"
|
||||
];
|
||||
})
|
||||
protobuf
|
||||
]
|
||||
++ (with python3.pkgs; [
|
||||
venvShellHook
|
||||
docker
|
||||
pip
|
||||
ipdb
|
||||
pyright
|
||||
pytest
|
||||
pytest-asyncio
|
||||
ruff
|
||||
syrupy
|
||||
]);
|
||||
|
||||
inputsFrom = [ server ];
|
||||
|
||||
venvDir = "./.venv";
|
||||
|
||||
postVenvCreation = ''
|
||||
unset SOURCE_DATE_EPOCH
|
||||
( cd server ; python -m pip install --no-dependencies -e . )
|
||||
( cd clients/python ; python -m pip install --no-dependencies -e . )
|
||||
'';
|
||||
postShellHook = ''
|
||||
unset SOURCE_DATE_EPOCH
|
||||
export PATH=$PATH:~/.cargo/bin
|
||||
'';
|
||||
};
|
||||
};
|
||||
|
||||
packages.default = pkgs.writeShellApplication {
|
||||
name = "text-generation-inference";
|
||||
runtimeInputs = [
|
||||
server
|
||||
router
|
||||
];
|
||||
text = ''
|
||||
${launcher}/bin/text-generation-launcher "$@"
|
||||
'';
|
||||
};
|
||||
}
|
||||
);
|
||||
}
|
|
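The flake above defines development shells and a default package wrapping `text-generation-launcher`. Assuming a Nix installation with flakes enabled, it could be used roughly like this (a sketch, not an officially documented workflow):

```bash
# Enter the impure development shell defined by the flake (sets up the Python venv hooks)
nix develop .#impure

# Build and run the default package, which forwards arguments to text-generation-launcher
nix run . -- --help
```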
@ -64,6 +64,7 @@ class ResponseComparator(JSONSnapshotExtension):
|
|||
self,
|
||||
data,
|
||||
*,
|
||||
include=None,
|
||||
exclude=None,
|
||||
matcher=None,
|
||||
):
|
||||
|
@ -79,7 +80,12 @@ class ResponseComparator(JSONSnapshotExtension):
|
|||
data = [d.model_dump() for d in data]
|
||||
|
||||
data = self._filter(
|
||||
data=data, depth=0, path=(), exclude=exclude, matcher=matcher
|
||||
data=data,
|
||||
depth=0,
|
||||
path=(),
|
||||
exclude=exclude,
|
||||
include=include,
|
||||
matcher=matcher,
|
||||
)
|
||||
return json.dumps(data, indent=2, ensure_ascii=False, sort_keys=False) + "\n"
|
||||
|
||||
|
@ -118,6 +124,7 @@ class ResponseComparator(JSONSnapshotExtension):
|
|||
and token.text == other.text
|
||||
and (
|
||||
self.ignore_logprob
|
||||
or (token.logprob == other.logprob and token.logprob is None)
|
||||
or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)
|
||||
)
|
||||
and token.special == other.special
|
||||
|
@ -256,7 +263,7 @@ class IgnoreLogProbResponseComparator(ResponseComparator):
|
|||
|
||||
class LauncherHandle:
|
||||
def __init__(self, port: int):
|
||||
self.client = AsyncClient(f"http://localhost:{port}")
|
||||
self.client = AsyncClient(f"http://localhost:{port}", timeout=30)
|
||||
|
||||
def _inner_health(self):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -41,22 +41,22 @@
|
|||
},
|
||||
{
|
||||
"id": 1669,
|
||||
"logprob": -1.5664062,
|
||||
"logprob": -1.5595703,
|
||||
"text": " il"
|
||||
},
|
||||
{
|
||||
"id": 11580,
|
||||
"logprob": -0.94189453,
|
||||
"logprob": -0.9428711,
|
||||
"text": " faut"
|
||||
},
|
||||
{
|
||||
"id": 3913,
|
||||
"logprob": -3.6816406,
|
||||
"logprob": -3.703125,
|
||||
"text": " tout"
|
||||
},
|
||||
{
|
||||
"id": 39261,
|
||||
"logprob": -1.7753906,
|
||||
"logprob": -1.7763672,
|
||||
"text": " d'abord"
|
||||
}
|
||||
],
|
||||
|
@ -64,7 +64,7 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 578,
|
||||
"logprob": -1.6318359,
|
||||
"logprob": -1.7822266,
|
||||
"special": false,
|
||||
"text": " le"
|
||||
},
|
||||
|
@ -76,7 +76,7 @@
|
|||
},
|
||||
{
|
||||
"id": 7735,
|
||||
"logprob": -2.4355469,
|
||||
"logprob": -2.4199219,
|
||||
"special": false,
|
||||
"text": " fond"
|
||||
},
|
||||
|
@ -88,19 +88,19 @@
|
|||
},
|
||||
{
|
||||
"id": 693,
|
||||
"logprob": -2.4472656,
|
||||
"logprob": -2.4628906,
|
||||
"special": false,
|
||||
"text": " à"
|
||||
},
|
||||
{
|
||||
"id": 366,
|
||||
"logprob": -1.1972656,
|
||||
"logprob": -1.1308594,
|
||||
"special": false,
|
||||
"text": " la"
|
||||
},
|
||||
{
|
||||
"id": 48844,
|
||||
"logprob": -1.7890625,
|
||||
"logprob": -1.7900391,
|
||||
"special": false,
|
||||
"text": " cass"
|
||||
},
|
||||
|
@ -118,7 +118,7 @@
|
|||
},
|
||||
{
|
||||
"id": 2940,
|
||||
"logprob": -1.9335938,
|
||||
"logprob": -1.9306641,
|
||||
"special": false,
|
||||
"text": " avec"
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas",
|
||||
"content": "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to appreciate nature.\n\nIn terms of temperature, the warmest times of the year are from June to August, when average high temperatures typically range from around 73°F or 23°C",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
|
@ -13,14 +13,14 @@
|
|||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1716553098,
|
||||
"created": 1724792495,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.5-dev0-native",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 100,
|
||||
"prompt_tokens": 62,
|
||||
"total_tokens": 162
|
||||
"prompt_tokens": 61,
|
||||
"total_tokens": 161
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,11 +8,11 @@
|
|||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -23,11 +23,11 @@
|
|||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -38,11 +38,11 @@
|
|||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -53,11 +53,11 @@
|
|||
"text": "hd"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -68,11 +68,11 @@
|
|||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -83,11 +83,11 @@
|
|||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -98,11 +98,11 @@
|
|||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -113,11 +113,11 @@
|
|||
"text": "aho"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -128,11 +128,11 @@
|
|||
"text": "2"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -143,11 +143,11 @@
|
|||
"text": "2"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -158,11 +158,11 @@
|
|||
"text": "2"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -173,11 +173,11 @@
|
|||
"text": "ima"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -188,11 +188,11 @@
|
|||
"text": "."
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -203,11 +203,11 @@
|
|||
"text": "."
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -218,11 +218,11 @@
|
|||
"text": "."
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -233,11 +233,11 @@
|
|||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -248,11 +248,11 @@
|
|||
"text": " Sarah"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -263,11 +263,11 @@
|
|||
"text": " Yes"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -278,11 +278,11 @@
|
|||
"text": " And"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -293,11 +293,11 @@
|
|||
"text": "i"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -308,11 +308,11 @@
|
|||
"text": "'"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -323,11 +323,11 @@
|
|||
"text": ","
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -338,11 +338,11 @@
|
|||
"text": " what"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -353,11 +353,11 @@
|
|||
"text": "'"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -368,11 +368,11 @@
|
|||
"text": "s"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -383,11 +383,11 @@
|
|||
"text": " Moh"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -398,11 +398,11 @@
|
|||
"text": " is"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -413,11 +413,11 @@
|
|||
"text": "m"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -428,11 +428,11 @@
|
|||
"text": " Room"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -443,11 +443,11 @@
|
|||
"text": "s"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -458,11 +458,11 @@
|
|||
"text": " the"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -473,11 +473,11 @@
|
|||
"text": " tired"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -488,11 +488,11 @@
|
|||
"text": ":"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -503,11 +503,11 @@
|
|||
"text": "'"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -518,11 +518,11 @@
|
|||
"text": " capital"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
|
@ -530,73 +530,73 @@
|
|||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": " of"
|
||||
"text": ","
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": " She"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"finish_reason": "length",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": " scale"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"finish_reason": "length",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": " of"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"finish_reason": "length",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": " being"
|
||||
"text": " its"
|
||||
}
|
||||
],
|
||||
"created": 1713284431,
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native"
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
}
|
||||
]
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
},
|
||||
{
|
||||
"id": 3102,
|
||||
"logprob": -11.1875,
|
||||
"logprob": -11.25,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
|
@ -24,66 +24,66 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 185,
|
||||
"logprob": -1.5546875,
|
||||
"logprob": -1.546875,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 549,
|
||||
"logprob": -2.84375,
|
||||
"logprob": -2.859375,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
},
|
||||
{
|
||||
"id": 1727,
|
||||
"logprob": -2.34375,
|
||||
"logprob": -2.484375,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
},
|
||||
{
|
||||
"id": 3102,
|
||||
"logprob": -0.8359375,
|
||||
"logprob": -0.83203125,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 317,
|
||||
"logprob": -1.0859375,
|
||||
"logprob": -1.1484375,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 254,
|
||||
"logprob": -1.5390625,
|
||||
"id": 245,
|
||||
"logprob": -1.578125,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1022,
|
||||
"logprob": -1.1875,
|
||||
"id": 3412,
|
||||
"logprob": -2.578125,
|
||||
"special": false,
|
||||
"text": " first"
|
||||
"text": " document"
|
||||
},
|
||||
{
|
||||
"id": 3458,
|
||||
"logprob": -0.35546875,
|
||||
"id": 344,
|
||||
"logprob": -1.125,
|
||||
"special": false,
|
||||
"text": " step"
|
||||
"text": " that"
|
||||
},
|
||||
{
|
||||
"id": 279,
|
||||
"logprob": -0.8828125,
|
||||
"id": 317,
|
||||
"logprob": -1.6953125,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 254,
|
||||
"logprob": -0.71484375,
|
||||
"id": 1222,
|
||||
"logprob": -1.71875,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
"text": " used"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\nThe test request is the first step in the"
|
||||
"generated_text": "\nThe test request is a document that is used"
|
||||
}
|
||||
|
|
|
@ -37,56 +37,56 @@
|
|||
},
|
||||
{
|
||||
"id": 1727,
|
||||
"logprob": -2.359375,
|
||||
"logprob": -2.4375,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
},
|
||||
{
|
||||
"id": 3102,
|
||||
"logprob": -0.83203125,
|
||||
"logprob": -0.83984375,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 317,
|
||||
"logprob": -1.125,
|
||||
"logprob": -1.1328125,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 245,
|
||||
"logprob": -1.5703125,
|
||||
"id": 254,
|
||||
"logprob": -1.515625,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 3412,
|
||||
"logprob": -2.578125,
|
||||
"id": 1022,
|
||||
"logprob": -1.15625,
|
||||
"special": false,
|
||||
"text": " document"
|
||||
"text": " first"
|
||||
},
|
||||
{
|
||||
"id": 344,
|
||||
"logprob": -1.125,
|
||||
"id": 3458,
|
||||
"logprob": -0.3671875,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
"text": " step"
|
||||
},
|
||||
{
|
||||
"id": 317,
|
||||
"logprob": -1.6953125,
|
||||
"id": 279,
|
||||
"logprob": -0.88671875,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 1222,
|
||||
"logprob": -1.75,
|
||||
"id": 254,
|
||||
"logprob": -0.69140625,
|
||||
"special": false,
|
||||
"text": " used"
|
||||
"text": " the"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\nThe test request is a document that is used"
|
||||
"generated_text": "\nThe test request is the first step in the"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
|
@ -126,56 +126,56 @@
|
|||
},
|
||||
{
|
||||
"id": 1727,
|
||||
"logprob": -2.359375,
|
||||
"logprob": -2.4375,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
},
|
||||
{
|
||||
"id": 3102,
|
||||
"logprob": -0.83203125,
|
||||
"logprob": -0.83984375,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 317,
|
||||
"logprob": -1.125,
|
||||
"logprob": -1.1328125,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 245,
|
||||
"logprob": -1.5703125,
|
||||
"id": 254,
|
||||
"logprob": -1.515625,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 3412,
|
||||
"logprob": -2.578125,
|
||||
"id": 1022,
|
||||
"logprob": -1.15625,
|
||||
"special": false,
|
||||
"text": " document"
|
||||
"text": " first"
|
||||
},
|
||||
{
|
||||
"id": 344,
|
||||
"logprob": -1.125,
|
||||
"id": 3458,
|
||||
"logprob": -0.3671875,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
"text": " step"
|
||||
},
|
||||
{
|
||||
"id": 317,
|
||||
"logprob": -1.6953125,
|
||||
"id": 279,
|
||||
"logprob": -0.88671875,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 1222,
|
||||
"logprob": -1.75,
|
||||
"id": 254,
|
||||
"logprob": -0.69140625,
|
||||
"special": false,
|
||||
"text": " used"
|
||||
"text": " the"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\nThe test request is a document that is used"
|
||||
"generated_text": "\nThe test request is the first step in the"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
|
@ -215,56 +215,56 @@
|
|||
},
|
||||
{
|
||||
"id": 1727,
|
||||
"logprob": -2.359375,
|
||||
"logprob": -2.4375,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
},
|
||||
{
|
||||
"id": 3102,
|
||||
"logprob": -0.83203125,
|
||||
"logprob": -0.83984375,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 317,
|
||||
"logprob": -1.125,
|
||||
"logprob": -1.1328125,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 245,
|
||||
"logprob": -1.5703125,
|
||||
"id": 254,
|
||||
"logprob": -1.515625,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 3412,
|
||||
"logprob": -2.578125,
|
||||
"id": 1022,
|
||||
"logprob": -1.15625,
|
||||
"special": false,
|
||||
"text": " document"
|
||||
"text": " first"
|
||||
},
|
||||
{
|
||||
"id": 344,
|
||||
"logprob": -1.125,
|
||||
"id": 3458,
|
||||
"logprob": -0.3671875,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
"text": " step"
|
||||
},
|
||||
{
|
||||
"id": 317,
|
||||
"logprob": -1.6953125,
|
||||
"id": 279,
|
||||
"logprob": -0.88671875,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 1222,
|
||||
"logprob": -1.75,
|
||||
"id": 254,
|
||||
"logprob": -0.69140625,
|
||||
"special": false,
|
||||
"text": " used"
|
||||
"text": " the"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\nThe test request is a document that is used"
|
||||
"generated_text": "\nThe test request is the first step in the"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
|
@ -304,55 +304,55 @@
|
|||
},
|
||||
{
|
||||
"id": 1727,
|
||||
"logprob": -2.359375,
|
||||
"logprob": -2.4375,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
},
|
||||
{
|
||||
"id": 3102,
|
||||
"logprob": -0.83203125,
|
||||
"logprob": -0.83984375,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 317,
|
||||
"logprob": -1.125,
|
||||
"logprob": -1.1328125,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 245,
|
||||
"logprob": -1.5703125,
|
||||
"id": 254,
|
||||
"logprob": -1.515625,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 3412,
|
||||
"logprob": -2.578125,
|
||||
"id": 1022,
|
||||
"logprob": -1.15625,
|
||||
"special": false,
|
||||
"text": " document"
|
||||
"text": " first"
|
||||
},
|
||||
{
|
||||
"id": 344,
|
||||
"logprob": -1.125,
|
||||
"id": 3458,
|
||||
"logprob": -0.3671875,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
"text": " step"
|
||||
},
|
||||
{
|
||||
"id": 317,
|
||||
"logprob": -1.6953125,
|
||||
"id": 279,
|
||||
"logprob": -0.88671875,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 1222,
|
||||
"logprob": -1.75,
|
||||
"id": 254,
|
||||
"logprob": -0.69140625,
|
||||
"special": false,
|
||||
"text": " used"
|
||||
"text": " the"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\nThe test request is a document that is used"
|
||||
"generated_text": "\nThe test request is the first step in the"
|
||||
}
|
||||
]
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"finish_reason": "stop_sequence",
|
||||
"generated_tokens": 5,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 128000,
|
||||
|
@ -16,7 +16,7 @@
|
|||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -10.375,
|
||||
"logprob": -10.4375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
|
@ -29,61 +29,31 @@
|
|||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 2209,
|
||||
"logprob": -2.78125,
|
||||
"id": 923,
|
||||
"logprob": -2.84375,
|
||||
"special": false,
|
||||
"text": " Is"
|
||||
"text": " add"
|
||||
},
|
||||
{
|
||||
"id": 279,
|
||||
"logprob": -0.6328125,
|
||||
"id": 264,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 734,
|
||||
"logprob": -2.703125,
|
||||
"special": false,
|
||||
"text": " function"
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 330,
|
||||
"logprob": -0.34179688,
|
||||
"logprob": -0.31640625,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 4110,
|
||||
"logprob": -2.359375,
|
||||
"id": 1985,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "Create"
|
||||
},
|
||||
{
|
||||
"id": 7575,
|
||||
"logprob": -2.1875,
|
||||
"special": false,
|
||||
"text": "Process"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": -0.07910156,
|
||||
"special": false,
|
||||
"text": "\""
|
||||
},
|
||||
{
|
||||
"id": 304,
|
||||
"logprob": -0.83203125,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 12468,
|
||||
"logprob": -1.8203125,
|
||||
"special": false,
|
||||
"text": " Win"
|
||||
"text": "test"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Test request: Is the function \"CreateProcess\" in Win"
|
||||
"generated_text": "Test request: add a \"test"
|
||||
}
|
||||
|
|
|
@ -12,12 +12,12 @@
|
|||
},
|
||||
{
|
||||
"id": 2323,
|
||||
"logprob": -9.421875,
|
||||
"logprob": -9.5625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -10.546875,
|
||||
"logprob": -10.375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
|
@ -25,61 +25,61 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 369,
|
||||
"logprob": -2.1816406,
|
||||
"logprob": -2.15625,
|
||||
"special": false,
|
||||
"text": " for"
|
||||
},
|
||||
{
|
||||
"id": 279,
|
||||
"logprob": -2.6992188,
|
||||
"logprob": -2.703125,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 220,
|
||||
"logprob": -3.6308594,
|
||||
"logprob": -3.640625,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 679,
|
||||
"logprob": -1.7988281,
|
||||
"logprob": -1.703125,
|
||||
"special": false,
|
||||
"text": "201"
|
||||
},
|
||||
{
|
||||
"id": 24,
|
||||
"logprob": -1.3535156,
|
||||
"logprob": -1.421875,
|
||||
"special": false,
|
||||
"text": "9"
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"logprob": -2.0058594,
|
||||
"logprob": -2.03125,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 2366,
|
||||
"logprob": -0.45410156,
|
||||
"logprob": -0.49023438,
|
||||
"special": false,
|
||||
"text": "202"
|
||||
},
|
||||
{
|
||||
"id": 15,
|
||||
"logprob": -0.037109375,
|
||||
"logprob": -0.041503906,
|
||||
"special": false,
|
||||
"text": "0"
|
||||
},
|
||||
{
|
||||
"id": 2978,
|
||||
"logprob": -0.8095703,
|
||||
"logprob": -0.87109375,
|
||||
"special": false,
|
||||
"text": " school"
|
||||
},
|
||||
{
|
||||
"id": 1060,
|
||||
"logprob": -0.013053894,
|
||||
"logprob": -0.012939453,
|
||||
"special": false,
|
||||
"text": " year"
|
||||
}
|
||||
|
@ -101,12 +101,12 @@
|
|||
},
|
||||
{
|
||||
"id": 2323,
|
||||
"logprob": -9.421875,
|
||||
"logprob": -9.5625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -10.546875,
|
||||
"logprob": -10.375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
|
@ -114,61 +114,61 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 369,
|
||||
"logprob": -2.1816406,
|
||||
"logprob": -2.15625,
|
||||
"special": false,
|
||||
"text": " for"
|
||||
},
|
||||
{
|
||||
"id": 279,
|
||||
"logprob": -2.6992188,
|
||||
"logprob": -2.703125,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 220,
|
||||
"logprob": -3.6308594,
|
||||
"logprob": -3.640625,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 679,
|
||||
"logprob": -1.7988281,
|
||||
"logprob": -1.703125,
|
||||
"special": false,
|
||||
"text": "201"
|
||||
},
|
||||
{
|
||||
"id": 24,
|
||||
"logprob": -1.3535156,
|
||||
"logprob": -1.421875,
|
||||
"special": false,
|
||||
"text": "9"
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"logprob": -2.0058594,
|
||||
"logprob": -2.03125,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 2366,
|
||||
"logprob": -0.45410156,
|
||||
"logprob": -0.49023438,
|
||||
"special": false,
|
||||
"text": "202"
|
||||
},
|
||||
{
|
||||
"id": 15,
|
||||
"logprob": -0.037109375,
|
||||
"logprob": -0.041503906,
|
||||
"special": false,
|
||||
"text": "0"
|
||||
},
|
||||
{
|
||||
"id": 2978,
|
||||
"logprob": -0.8095703,
|
||||
"logprob": -0.87109375,
|
||||
"special": false,
|
||||
"text": " school"
|
||||
},
|
||||
{
|
||||
"id": 1060,
|
||||
"logprob": -0.013053894,
|
||||
"logprob": -0.012939453,
|
||||
"special": false,
|
||||
"text": " year"
|
||||
}
|
||||
|
@ -190,12 +190,12 @@
|
|||
},
|
||||
{
|
||||
"id": 2323,
|
||||
"logprob": -9.421875,
|
||||
"logprob": -9.5625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -10.546875,
|
||||
"logprob": -10.375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
|
@ -203,61 +203,61 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 369,
|
||||
"logprob": -2.1816406,
|
||||
"logprob": -2.15625,
|
||||
"special": false,
|
||||
"text": " for"
|
||||
},
|
||||
{
|
||||
"id": 279,
|
||||
"logprob": -2.6992188,
|
||||
"logprob": -2.703125,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 220,
|
||||
"logprob": -3.6308594,
|
||||
"logprob": -3.640625,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 679,
|
||||
"logprob": -1.7988281,
|
||||
"logprob": -1.703125,
|
||||
"special": false,
|
||||
"text": "201"
|
||||
},
|
||||
{
|
||||
"id": 24,
|
||||
"logprob": -1.3535156,
|
||||
"logprob": -1.421875,
|
||||
"special": false,
|
||||
"text": "9"
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"logprob": -2.0058594,
|
||||
"logprob": -2.03125,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 2366,
|
||||
"logprob": -0.45410156,
|
||||
"logprob": -0.49023438,
|
||||
"special": false,
|
||||
"text": "202"
|
||||
},
|
||||
{
|
||||
"id": 15,
|
||||
"logprob": -0.037109375,
|
||||
"logprob": -0.041503906,
|
||||
"special": false,
|
||||
"text": "0"
|
||||
},
|
||||
{
|
||||
"id": 2978,
|
||||
"logprob": -0.8095703,
|
||||
"logprob": -0.87109375,
|
||||
"special": false,
|
||||
"text": " school"
|
||||
},
|
||||
{
|
||||
"id": 1060,
|
||||
"logprob": -0.013053894,
|
||||
"logprob": -0.012939453,
|
||||
"special": false,
|
||||
"text": " year"
|
||||
}
|
||||
|
@ -279,12 +279,12 @@
|
|||
},
|
||||
{
|
||||
"id": 2323,
|
||||
"logprob": -9.421875,
|
||||
"logprob": -9.5625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -10.546875,
|
||||
"logprob": -10.375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
|
@ -292,61 +292,61 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 369,
|
||||
"logprob": -2.1816406,
|
||||
"logprob": -2.15625,
|
||||
"special": false,
|
||||
"text": " for"
|
||||
},
|
||||
{
|
||||
"id": 279,
|
||||
"logprob": -2.6992188,
|
||||
"logprob": -2.703125,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 220,
|
||||
"logprob": -3.6308594,
|
||||
"logprob": -3.640625,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 679,
|
||||
"logprob": -1.7988281,
|
||||
"logprob": -1.703125,
|
||||
"special": false,
|
||||
"text": "201"
|
||||
},
|
||||
{
|
||||
"id": 24,
|
||||
"logprob": -1.3535156,
|
||||
"logprob": -1.421875,
|
||||
"special": false,
|
||||
"text": "9"
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"logprob": -2.0058594,
|
||||
"logprob": -2.03125,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 2366,
|
||||
"logprob": -0.45410156,
|
||||
"logprob": -0.49023438,
|
||||
"special": false,
|
||||
"text": "202"
|
||||
},
|
||||
{
|
||||
"id": 15,
|
||||
"logprob": -0.037109375,
|
||||
"logprob": -0.041503906,
|
||||
"special": false,
|
||||
"text": "0"
|
||||
},
|
||||
{
|
||||
"id": 2978,
|
||||
"logprob": -0.8095703,
|
||||
"logprob": -0.87109375,
|
||||
"special": false,
|
||||
"text": " school"
|
||||
},
|
||||
{
|
||||
"id": 1060,
|
||||
"logprob": -0.013053894,
|
||||
"logprob": -0.012939453,
|
||||
"special": false,
|
||||
"text": " year"
|
||||
}
|
||||
|
|
|
@ -8,13 +8,13 @@
"tokens": [
{
"id": 54901,
"logprob": -0.72753906,
"logprob": -0.84765625,
"special": false,
"text": "beach"
},
{
"id": 1,
"logprob": -0.011009216,
"logprob": -0.008666992,
"special": true,
"text": "<eos>"
}

@ -19,25 +19,25 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": -0.19421387,
|
||||
"logprob": -0.28955078,
|
||||
"special": false,
|
||||
"text": " to"
|
||||
},
|
||||
{
|
||||
"id": 3758,
|
||||
"logprob": -0.62597656,
|
||||
"logprob": -0.7739258,
|
||||
"special": false,
|
||||
"text": " send"
|
||||
},
|
||||
{
|
||||
"id": 1366,
|
||||
"logprob": -0.87060547,
|
||||
"logprob": -0.85253906,
|
||||
"special": false,
|
||||
"text": " data"
|
||||
},
|
||||
{
|
||||
"id": 625,
|
||||
"logprob": -0.88427734,
|
||||
"logprob": -0.8984375,
|
||||
"special": false,
|
||||
"text": " over"
|
||||
},
|
||||
|
@ -49,7 +49,7 @@
|
|||
},
|
||||
{
|
||||
"id": 3127,
|
||||
"logprob": -1.9462891,
|
||||
"logprob": -1.9404297,
|
||||
"special": false,
|
||||
"text": " network"
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
},
|
||||
{
|
||||
"id": 100,
|
||||
"logprob": -0.38549805,
|
||||
"logprob": -0.38305664,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
|
@ -29,7 +29,7 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 2284,
|
||||
"logprob": -0.31323242,
|
||||
"logprob": -0.296875,
|
||||
"special": false,
|
||||
"text": "():"
|
||||
},
|
||||
|
@ -59,19 +59,19 @@
|
|||
},
|
||||
{
|
||||
"id": 10914,
|
||||
"logprob": -0.7817383,
|
||||
"logprob": -0.7734375,
|
||||
"special": false,
|
||||
"text": " World"
|
||||
},
|
||||
{
|
||||
"id": 16013,
|
||||
"logprob": -0.6328125,
|
||||
"logprob": -0.61816406,
|
||||
"special": false,
|
||||
"text": "!\")"
|
||||
},
|
||||
{
|
||||
"id": 222,
|
||||
"logprob": -0.0619812,
|
||||
"logprob": -0.054870605,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
|
@ -83,7 +83,7 @@
|
|||
},
|
||||
{
|
||||
"id": 610,
|
||||
"logprob": -0.4086914,
|
||||
"logprob": -0.4152832,
|
||||
"special": false,
|
||||
"text": "def"
|
||||
},
|
||||
|
@ -113,7 +113,7 @@
|
|||
},
|
||||
{
|
||||
"id": 444,
|
||||
"logprob": -0.21826172,
|
||||
"logprob": -0.21618652,
|
||||
"special": false,
|
||||
"text": "name"
|
||||
},
|
||||
|
@ -173,7 +173,7 @@
|
|||
},
|
||||
{
|
||||
"id": 11571,
|
||||
"logprob": -0.10021973,
|
||||
"logprob": -0.08892822,
|
||||
"special": false,
|
||||
"text": "!\""
|
||||
},
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 20,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 2,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 589,
|
||||
|
@ -11,57 +11,57 @@
|
|||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"logprob": -8.5859375,
|
||||
"logprob": -8.9453125,
|
||||
"text": " ge"
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"logprob": -7.5859375,
|
||||
"logprob": -8.8515625,
|
||||
"text": "ometric"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": -0.2668457,
|
||||
"logprob": -0.21875,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 6009,
|
||||
"logprob": -1.6416016,
|
||||
"logprob": -1.2773438,
|
||||
"text": "mean"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.22705078,
|
||||
"logprob": -0.25195312,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -5.2304688,
|
||||
"logprob": -4.8203125,
|
||||
"text": "L"
|
||||
},
|
||||
{
|
||||
"id": 44,
|
||||
"logprob": -3.0976562,
|
||||
"logprob": -3.7734375,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 1682,
|
||||
"logprob": -1.1044922,
|
||||
"logprob": -0.8310547,
|
||||
"text": " List"
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"logprob": -0.14294434,
|
||||
"logprob": -0.22766113,
|
||||
"text": "["
|
||||
},
|
||||
{
|
||||
"id": 1808,
|
||||
"logprob": -0.32299805,
|
||||
"logprob": -0.46240234,
|
||||
"text": "float"
|
||||
},
|
||||
{
|
||||
"id": 10794,
|
||||
"logprob": -2.8164062,
|
||||
"logprob": -3.0234375,
|
||||
"text": "]):"
|
||||
}
|
||||
],
|
||||
|
@ -69,126 +69,18 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": -0.1282959,
|
||||
"logprob": -0.04626465,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1524,
|
||||
"logprob": -0.97998047,
|
||||
"special": false,
|
||||
"text": " \"\"\""
|
||||
},
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": -0.7006836,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 14883,
|
||||
"logprob": -2.1933594,
|
||||
"special": false,
|
||||
"text": " Calculate"
|
||||
},
|
||||
{
|
||||
"id": 322,
|
||||
"logprob": -0.2697754,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"logprob": -0.0836792,
|
||||
"special": false,
|
||||
"text": " ge"
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"logprob": -0.018737793,
|
||||
"special": false,
|
||||
"text": "ometric"
|
||||
},
|
||||
{
|
||||
"id": 5651,
|
||||
"logprob": -0.028640747,
|
||||
"special": false,
|
||||
"text": " mean"
|
||||
},
|
||||
{
|
||||
"id": 432,
|
||||
"logprob": -0.29467773,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 312,
|
||||
"logprob": -0.31518555,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1149,
|
||||
"logprob": -0.20605469,
|
||||
"special": false,
|
||||
"text": " list"
|
||||
},
|
||||
{
|
||||
"id": 432,
|
||||
"logprob": -0.23254395,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 7515,
|
||||
"logprob": -0.4489746,
|
||||
"special": false,
|
||||
"text": " numbers"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"logprob": -0.6044922,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 446,
|
||||
"logprob": -0.63964844,
|
||||
"special": false,
|
||||
"text": "\n\n "
|
||||
},
|
||||
{
|
||||
"id": 499,
|
||||
"logprob": -1.1953125,
|
||||
"special": false,
|
||||
"text": " :"
|
||||
},
|
||||
{
|
||||
"id": 753,
|
||||
"logprob": -0.03515625,
|
||||
"special": false,
|
||||
"text": "param"
|
||||
},
|
||||
{
|
||||
"id": 498,
|
||||
"logprob": -0.06311035,
|
||||
"special": false,
|
||||
"text": " L"
|
||||
},
|
||||
{
|
||||
"id": 44,
|
||||
"logprob": -0.003414154,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 1682,
|
||||
"logprob": -1.3310547,
|
||||
"special": false,
|
||||
"text": " List"
|
||||
"id": 0,
|
||||
"logprob": null,
|
||||
"special": true,
|
||||
"text": "<|endoftext|>"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a list of numbers.\n\n :param L: List"
|
||||
"generated_text": "\n "
|
||||
}
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 20,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 2,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 589,
|
||||
|
@ -11,57 +11,57 @@
|
|||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"logprob": -8.5859375,
|
||||
"logprob": -8.9453125,
|
||||
"text": " ge"
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"logprob": -7.5898438,
|
||||
"logprob": -8.859375,
|
||||
"text": "ometric"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": -0.26586914,
|
||||
"logprob": -0.21984863,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 6009,
|
||||
"logprob": -1.6347656,
|
||||
"logprob": -1.2861328,
|
||||
"text": "mean"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.22705078,
|
||||
"logprob": -0.25219727,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -5.2382812,
|
||||
"logprob": -4.8007812,
|
||||
"text": "L"
|
||||
},
|
||||
{
|
||||
"id": 44,
|
||||
"logprob": -3.0996094,
|
||||
"logprob": -3.7949219,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 1682,
|
||||
"logprob": -1.1025391,
|
||||
"logprob": -0.8046875,
|
||||
"text": " List"
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"logprob": -0.14294434,
|
||||
"logprob": -0.22424316,
|
||||
"text": "["
|
||||
},
|
||||
{
|
||||
"id": 1808,
|
||||
"logprob": -0.32226562,
|
||||
"logprob": -0.46191406,
|
||||
"text": "float"
|
||||
},
|
||||
{
|
||||
"id": 10794,
|
||||
"logprob": -2.8164062,
|
||||
"logprob": -3.0253906,
|
||||
"text": "]):"
|
||||
}
|
||||
],
|
||||
|
@ -74,121 +74,13 @@
|
|||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 442,
|
||||
"logprob": -1.3134766,
|
||||
"special": false,
|
||||
"text": " return"
|
||||
},
|
||||
{
|
||||
"id": 11665,
|
||||
"logprob": -0.10021973,
|
||||
"special": false,
|
||||
"text": " reduce"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 5962,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "lambda"
|
||||
},
|
||||
{
|
||||
"id": 816,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " x"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 533,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " y"
|
||||
},
|
||||
{
|
||||
"id": 44,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 816,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " x"
|
||||
},
|
||||
{
|
||||
"id": 319,
|
||||
"logprob": -0.42871094,
|
||||
"special": false,
|
||||
"text": " *"
|
||||
},
|
||||
{
|
||||
"id": 533,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " y"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 498,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " L"
|
||||
},
|
||||
{
|
||||
"id": 27,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ")"
|
||||
},
|
||||
{
|
||||
"id": 1115,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " **"
|
||||
},
|
||||
{
|
||||
"id": 308,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " ("
|
||||
},
|
||||
{
|
||||
"id": 35,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "1"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"logprob": -0.31323242,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 34,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "0"
|
||||
"id": 0,
|
||||
"logprob": null,
|
||||
"special": true,
|
||||
"text": "<|endoftext|>"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n return reduce(lambda x, y: x * y, L) ** (1.0"
|
||||
"generated_text": "\n "
|
||||
}
|
||||
|
|
|
@ -2,8 +2,8 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 2,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 589,
|
||||
|
@ -12,57 +12,57 @@
|
|||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"logprob": -8.5859375,
|
||||
"logprob": -8.9453125,
|
||||
"text": " ge"
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"logprob": -7.5820312,
|
||||
"logprob": -8.8515625,
|
||||
"text": "ometric"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": -0.26708984,
|
||||
"logprob": -0.22033691,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 6009,
|
||||
"logprob": -1.6386719,
|
||||
"logprob": -1.2939453,
|
||||
"text": "mean"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.22717285,
|
||||
"logprob": -0.25268555,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -5.234375,
|
||||
"logprob": -4.796875,
|
||||
"text": "L"
|
||||
},
|
||||
{
|
||||
"id": 44,
|
||||
"logprob": -3.1015625,
|
||||
"logprob": -3.796875,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 1682,
|
||||
"logprob": -1.1083984,
|
||||
"logprob": -0.8066406,
|
||||
"text": " List"
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"logprob": -0.14294434,
|
||||
"logprob": -0.22644043,
|
||||
"text": "["
|
||||
},
|
||||
{
|
||||
"id": 1808,
|
||||
"logprob": -0.32592773,
|
||||
"logprob": -0.46166992,
|
||||
"text": "float"
|
||||
},
|
||||
{
|
||||
"id": 10794,
|
||||
"logprob": -2.8164062,
|
||||
"logprob": -3.0253906,
|
||||
"text": "]):"
|
||||
}
|
||||
],
|
||||
|
@ -70,74 +70,26 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": -0.12817383,
|
||||
"logprob": -0.046844482,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1524,
|
||||
"logprob": -0.9863281,
|
||||
"special": false,
|
||||
"text": " \"\"\""
|
||||
},
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": -0.7011719,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 14883,
|
||||
"logprob": -2.2050781,
|
||||
"special": false,
|
||||
"text": " Calculate"
|
||||
},
|
||||
{
|
||||
"id": 322,
|
||||
"logprob": -0.2668457,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"logprob": -0.08465576,
|
||||
"special": false,
|
||||
"text": " ge"
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"logprob": -0.019012451,
|
||||
"special": false,
|
||||
"text": "ometric"
|
||||
},
|
||||
{
|
||||
"id": 5651,
|
||||
"logprob": -0.028625488,
|
||||
"special": false,
|
||||
"text": " mean"
|
||||
},
|
||||
{
|
||||
"id": 432,
|
||||
"logprob": -0.29418945,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 312,
|
||||
"logprob": -0.3161621,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
"id": 0,
|
||||
"logprob": null,
|
||||
"special": true,
|
||||
"text": "<|endoftext|>"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
|
||||
"generated_text": "\n "
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 2,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 589,
|
||||
|
@ -146,57 +98,57 @@
|
|||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"logprob": -8.5859375,
|
||||
"logprob": -8.9375,
|
||||
"text": " ge"
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"logprob": -7.59375,
|
||||
"logprob": -8.8515625,
|
||||
"text": "ometric"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": -0.26953125,
|
||||
"logprob": -0.21826172,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 6009,
|
||||
"logprob": -1.640625,
|
||||
"logprob": -1.2871094,
|
||||
"text": "mean"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.22705078,
|
||||
"logprob": -0.25390625,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -5.234375,
|
||||
"logprob": -4.8085938,
|
||||
"text": "L"
|
||||
},
|
||||
{
|
||||
"id": 44,
|
||||
"logprob": -3.1132812,
|
||||
"logprob": -3.7890625,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 1682,
|
||||
"logprob": -1.1123047,
|
||||
"logprob": -0.8076172,
|
||||
"text": " List"
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"logprob": -0.14294434,
|
||||
"logprob": -0.22302246,
|
||||
"text": "["
|
||||
},
|
||||
{
|
||||
"id": 1808,
|
||||
"logprob": -0.32299805,
|
||||
"logprob": -0.46435547,
|
||||
"text": "float"
|
||||
},
|
||||
{
|
||||
"id": 10794,
|
||||
"logprob": -2.8164062,
|
||||
"logprob": -3.0234375,
|
||||
"text": "]):"
|
||||
}
|
||||
],
|
||||
|
@ -204,74 +156,26 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": -0.12854004,
|
||||
"logprob": -0.046722412,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1524,
|
||||
"logprob": -0.9897461,
|
||||
"special": false,
|
||||
"text": " \"\"\""
|
||||
},
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": -0.69970703,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 14883,
|
||||
"logprob": -2.2050781,
|
||||
"special": false,
|
||||
"text": " Calculate"
|
||||
},
|
||||
{
|
||||
"id": 322,
|
||||
"logprob": -0.2668457,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"logprob": -0.08496094,
|
||||
"special": false,
|
||||
"text": " ge"
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"logprob": -0.019012451,
|
||||
"special": false,
|
||||
"text": "ometric"
|
||||
},
|
||||
{
|
||||
"id": 5651,
|
||||
"logprob": -0.029037476,
|
||||
"special": false,
|
||||
"text": " mean"
|
||||
},
|
||||
{
|
||||
"id": 432,
|
||||
"logprob": -0.2939453,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 312,
|
||||
"logprob": -0.31591797,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
"id": 0,
|
||||
"logprob": null,
|
||||
"special": true,
|
||||
"text": "<|endoftext|>"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
|
||||
"generated_text": "\n "
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 2,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 589,
|
||||
|
@ -280,57 +184,57 @@
|
|||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"logprob": -8.5859375,
|
||||
"logprob": -8.9453125,
|
||||
"text": " ge"
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"logprob": -7.5859375,
|
||||
"logprob": -8.8515625,
|
||||
"text": "ometric"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": -0.26586914,
|
||||
"logprob": -0.21813965,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 6009,
|
||||
"logprob": -1.6347656,
|
||||
"logprob": -1.2744141,
|
||||
"text": "mean"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.22766113,
|
||||
"logprob": -0.2512207,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -5.2265625,
|
||||
"logprob": -4.8046875,
|
||||
"text": "L"
|
||||
},
|
||||
{
|
||||
"id": 44,
|
||||
"logprob": -3.0976562,
|
||||
"logprob": -3.7851562,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 1682,
|
||||
"logprob": -1.1025391,
|
||||
"logprob": -0.81396484,
|
||||
"text": " List"
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"logprob": -0.1427002,
|
||||
"logprob": -0.22570801,
|
||||
"text": "["
|
||||
},
|
||||
{
|
||||
"id": 1808,
|
||||
"logprob": -0.32592773,
|
||||
"logprob": -0.46044922,
|
||||
"text": "float"
|
||||
},
|
||||
{
|
||||
"id": 10794,
|
||||
"logprob": -2.8164062,
|
||||
"logprob": -3.0234375,
|
||||
"text": "]):"
|
||||
}
|
||||
],
|
||||
|
@ -338,74 +242,26 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": -0.13012695,
|
||||
"logprob": -0.04650879,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1524,
|
||||
"logprob": -0.98046875,
|
||||
"special": false,
|
||||
"text": " \"\"\""
|
||||
},
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": -0.69921875,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 14883,
|
||||
"logprob": -2.1992188,
|
||||
"special": false,
|
||||
"text": " Calculate"
|
||||
},
|
||||
{
|
||||
"id": 322,
|
||||
"logprob": -0.2668457,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"logprob": -0.083496094,
|
||||
"special": false,
|
||||
"text": " ge"
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"logprob": -0.01902771,
|
||||
"special": false,
|
||||
"text": "ometric"
|
||||
},
|
||||
{
|
||||
"id": 5651,
|
||||
"logprob": -0.029006958,
|
||||
"special": false,
|
||||
"text": " mean"
|
||||
},
|
||||
{
|
||||
"id": 432,
|
||||
"logprob": -0.29248047,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 312,
|
||||
"logprob": -0.3161621,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
"id": 0,
|
||||
"logprob": null,
|
||||
"special": true,
|
||||
"text": "<|endoftext|>"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
|
||||
"generated_text": "\n "
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 2,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 589,
|
||||
|
@ -414,57 +270,57 @@
|
|||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"logprob": -8.5859375,
|
||||
"logprob": -8.9453125,
|
||||
"text": " ge"
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"logprob": -7.5859375,
|
||||
"logprob": -8.8515625,
|
||||
"text": "ometric"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": -0.26904297,
|
||||
"logprob": -0.21960449,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 6009,
|
||||
"logprob": -1.6386719,
|
||||
"logprob": -1.2890625,
|
||||
"text": "mean"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.22705078,
|
||||
"logprob": -0.25073242,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -5.234375,
|
||||
"logprob": -4.8085938,
|
||||
"text": "L"
|
||||
},
|
||||
{
|
||||
"id": 44,
|
||||
"logprob": -3.1132812,
|
||||
"logprob": -3.8046875,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 1682,
|
||||
"logprob": -1.1074219,
|
||||
"logprob": -0.8071289,
|
||||
"text": " List"
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"logprob": -0.14477539,
|
||||
"logprob": -0.22570801,
|
||||
"text": "["
|
||||
},
|
||||
{
|
||||
"id": 1808,
|
||||
"logprob": -0.3256836,
|
||||
"logprob": -0.46118164,
|
||||
"text": "float"
|
||||
},
|
||||
{
|
||||
"id": 10794,
|
||||
"logprob": -2.8027344,
|
||||
"logprob": -3.0097656,
|
||||
"text": "]):"
|
||||
}
|
||||
],
|
||||
|
@ -472,67 +328,19 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": -0.12915039,
|
||||
"logprob": -0.046539307,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 1524,
|
||||
"logprob": -0.98535156,
|
||||
"special": false,
|
||||
"text": " \"\"\""
|
||||
},
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": -0.69921875,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 14883,
|
||||
"logprob": -2.2011719,
|
||||
"special": false,
|
||||
"text": " Calculate"
|
||||
},
|
||||
{
|
||||
"id": 322,
|
||||
"logprob": -0.26708984,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"logprob": -0.08502197,
|
||||
"special": false,
|
||||
"text": " ge"
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"logprob": -0.019012451,
|
||||
"special": false,
|
||||
"text": "ometric"
|
||||
},
|
||||
{
|
||||
"id": 5651,
|
||||
"logprob": -0.028625488,
|
||||
"special": false,
|
||||
"text": " mean"
|
||||
},
|
||||
{
|
||||
"id": 432,
|
||||
"logprob": -0.29589844,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 312,
|
||||
"logprob": -0.31591797,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
"id": 0,
|
||||
"logprob": null,
|
||||
"special": true,
|
||||
"text": "<|endoftext|>"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
|
||||
"generated_text": "\n "
|
||||
}
|
||||
]
|
||||
|
|
|
@ -30,19 +30,19 @@
|
|||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.37573242,
|
||||
"logprob": -0.38061523,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 633,
|
||||
"logprob": -0.09161377,
|
||||
"logprob": -0.09301758,
|
||||
"special": false,
|
||||
"text": " new"
|
||||
},
|
||||
{
|
||||
"id": 4480,
|
||||
"logprob": -0.26171875,
|
||||
"logprob": -0.26782227,
|
||||
"special": false,
|
||||
"text": " feature"
|
||||
},
|
||||
|
@ -78,7 +78,7 @@
|
|||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": 0.0,
|
||||
"logprob": -0.10632324,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
}
|
||||
File diff suppressed because it is too large

@ -26,13 +26,13 @@
},
{
"id": 259,
"logprob": -0.4716797,
"logprob": -0.46948242,
"special": false,
"text": " "
},
{
"id": 261,
"logprob": -0.044677734,
"logprob": -0.15307617,
"special": false,
"text": ","
},

@ -56,7 +56,7 @@
},
{
"id": 35622,
"logprob": -1.1630859,
"logprob": -1.2998047,
"special": false,
"text": " cloud"
},

@ -35,6 +35,6 @@ async def test_flash_llama_simple(flash_llama_chat, response_snapshot):
print(repr(response.choices[0].message.content))
assert (
response.choices[0].message.content
== "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas"
== "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to appreciate nature.\n\nIn terms of temperature, the warmest times of the year are from June to August, when average high temperatures typically range from around 73°F or 23°C"
)
assert response == response_snapshot

@ -21,7 +21,6 @@ async def flash_llama_exl2(flash_llama_exl2_handle):
return flash_llama_exl2_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):

@ -33,7 +32,6 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh
assert response == ignore_logprob_response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_all_params(

@ -60,7 +58,6 @@ async def test_flash_llama_exl2_all_params(
assert response == ignore_logprob_response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_load(

@ -21,6 +21,7 @@ async def test_flash_llama_fp8(flash_llama_fp8, response_snapshot):
"Test request", max_new_tokens=10, decoder_input_details=True
)

assert response.generated_text == " for the 2019-2020 school year"
assert response.details.generated_tokens == 10
assert response == response_snapshot

@ -57,6 +58,8 @@ async def test_flash_llama_fp8_load(flash_llama_fp8, generate_load, response_sna
)

assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])

assert responses[0].generated_text == " for the 2019-2020 school year"
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"Different messages : {[r.generated_text for r in responses]}"
assert responses == response_snapshot

@ -21,7 +21,7 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
max_new_tokens=20,
decoder_input_details=True,
)
assert response.details.generated_tokens == 20
assert response.details.generated_tokens == 2
assert response == generous_response_snapshot


@ -38,7 +38,7 @@ async def test_flash_starcoder_gptq_default_params(
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 20
assert response.details.generated_tokens == 2
assert response == generous_response_snapshot

@ -98,6 +98,6 @@ async def test_grammar_response_format_llama_error_if_tools_not_installed(
# 422 means the server was unable to process the request because it contains invalid data.
assert response.status_code == 422
assert response.json() == {
"error": "Grammar and tools are mutually exclusive",
"error_type": "grammar and tools",
"error": "Tool error: Grammar and tools are mutually exclusive",
"error_type": "tool_error",
}

@ -62,6 +62,7 @@ async def test_mamba_load(
)

assert len(responses) == 4
assert responses[0].generated_text == "\n\nDeep learning is a new type of machine"
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses[0].generated_text == "\n\nDeep learning is a new type of machine"

@ -0,0 +1,19 @@
import pytest


@pytest.fixture(scope="module")
def opt_sharded_handle(launcher):
with launcher("facebook/opt-6.7b", num_shard=2) as handle:
yield handle


@pytest.fixture(scope="module")
async def opt_sharded(opt_sharded_handle):
await opt_sharded_handle.health(300)
return opt_sharded_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_opt(opt_sharded):
pass

@ -36,6 +36,7 @@ tools = [
},
},
"required": ["location", "format"],
"additionalProperties": False,
},
},
},

@ -62,13 +63,13 @@ tools = [
},
},
"required": ["location", "format", "num_days"],
"additionalProperties": False,
},
},
},
]


@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):

@ -76,7 +77,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
max_tokens=100,
seed=1,
tools=tools,
presence_penalty=-1.1,
temperature=0.0,
messages=[
{
"role": "system",

@ -91,19 +92,18 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
"id": "0",
"type": "function",
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
},
}
]
assert response == response_snapshot


@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_auto(

@ -113,8 +113,8 @@ async def test_flash_llama_grammar_tools_auto(
max_tokens=100,
seed=1,
tools=tools,
temperature=0.0,
tool_choice="auto",
presence_penalty=-1.1,
messages=[
{
"role": "system",

@ -129,12 +129,12 @@ async def test_flash_llama_grammar_tools_auto(
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
"id": "0",
"type": "function",
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
},
}
]

@ -142,7 +142,6 @@ async def test_flash_llama_grammar_tools_auto(
assert response == response_snapshot


@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_choice(

@ -152,8 +151,8 @@ async def test_flash_llama_grammar_tools_choice(
max_tokens=100,
seed=1,
tools=tools,
temperature=0.0,
tool_choice="get_current_weather",
presence_penalty=-1.1,
messages=[
{
"role": "system",

@ -168,12 +167,12 @@ async def test_flash_llama_grammar_tools_choice(
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
"id": "0",
"type": "function",
"function": {
"description": None,
"name": "get_current_weather",
"arguments": {"format": "celsius", "location": "New York, NY"},
"arguments": {"format": "celsius", "location": "Brooklyn, NY"},
},
}
]

@ -181,7 +180,6 @@ async def test_flash_llama_grammar_tools_choice(
assert response == response_snapshot


@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_stream(

@ -191,8 +189,8 @@ async def test_flash_llama_grammar_tools_stream(
max_tokens=100,
seed=1,
tools=tools,
temperature=0.0,
tool_choice="get_current_weather",
presence_penalty=-1.1,
messages=[
{
"role": "system",

@ -210,11 +208,10 @@ async def test_flash_llama_grammar_tools_stream(
async for response in responses:
count += 1

assert count == 38
assert count == 48
assert response == response_snapshot


@pytest.mark.skip(reason="Takes too long to run")
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information(

@ -222,13 +219,13 @@ async def test_flash_llama_grammar_tools_insufficient_information(
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=8,
seed=24,
tools=tools,
tool_choice="auto",
messages=[
{
"role": "system",
"content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
"content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
},
{
"role": "user",

@ -239,18 +236,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
)

assert responses.choices[0].message.content is None
assert responses.choices[0].message.tool_calls == [
{
"function": {
"arguments": {
"error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
},
"description": None,
"name": "notify_error",
},
"id": 0,
"type": "function",
}
]

assert (
responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
)
assert responses == response_snapshot

@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohttp"
|
||||
|
@ -268,16 +268,6 @@ files = [
|
|||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colored"
|
||||
version = "1.4.4"
|
||||
description = "Simple library for color and formatting to terminal"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "colored-1.4.4.tar.gz", hash = "sha256:04ff4d4dd514274fe3b99a21bb52fb96f2688c01e93fba7bef37221e7cb56ce0"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "docker"
|
||||
version = "6.1.3"
|
||||
|
@ -855,18 +845,17 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
|||
|
||||
[[package]]
|
||||
name = "syrupy"
|
||||
version = "4.0.1"
|
||||
version = "4.7.1"
|
||||
description = "Pytest Snapshot Test Utility"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4"
|
||||
python-versions = ">=3.8.1"
|
||||
files = [
|
||||
{file = "syrupy-4.0.1-py3-none-any.whl", hash = "sha256:53d3107cc5e18a5def189c721879cea2cdafdee34b879f602133ca08837d0e4b"},
|
||||
{file = "syrupy-4.0.1.tar.gz", hash = "sha256:60e3e94782444e0f978cd3b207de32f6da3199b15a2db32eab02f83cebb63ae8"},
|
||||
{file = "syrupy-4.7.1-py3-none-any.whl", hash = "sha256:be002267a512a4bedddfae2e026c93df1ea928ae10baadc09640516923376d41"},
|
||||
{file = "syrupy-4.7.1.tar.gz", hash = "sha256:f9d4485f3f27d0e5df6ed299cac6fa32eb40a441915d988e82be5a4bdda335c8"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
colored = ">=1.3.92,<2.0.0"
|
||||
pytest = ">=7.0.0,<8.0.0"
|
||||
pytest = ">=7.0.0,<9.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "text-generation"
|
||||
|
@ -1049,4 +1038,4 @@ multidict = ">=4.0"
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<3.13"
|
||||
content-hash = "421fbce065cb1499c666599cf0fd83a5ce8fb3bed09e83c16c3a3d6953b34026"
|
||||
content-hash = "f5c65e704b02250d73055cd04efcc22f8fc36144eddfc447a71c3a061748db80"
|
||||
|
|
|
@ -7,7 +7,7 @@ authors = ["Nicolas Patry <nicolas@huggingface.co>"]
|
|||
[tool.poetry.dependencies]
|
||||
pydantic = "> 2, < 3"
|
||||
python = ">=3.9,<3.13"
|
||||
syrupy = "4.0.1"
|
||||
syrupy = "^4.7.1"
|
||||
text-generation = "^0.6.0"
|
||||
pytest = "^7.4.0"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
|
|
|
@ -6,7 +6,6 @@ attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
|||
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
colored==1.4.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
docker==6.1.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
exceptiongroup==1.1.3 ; python_version >= "3.9" and python_version < "3.11"
|
||||
filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
@ -25,7 +24,7 @@ pytest==7.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
|||
pywin32==306 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
syrupy==4.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
syrupy==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
text-generation==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11"
|
||||
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
|
|
@ -8,7 +8,7 @@ use nix::unistd::Pid;
|
|||
use serde::Deserialize;
|
||||
use std::env;
|
||||
use std::ffi::OsString;
|
||||
use std::io::{BufRead, BufReader, Lines};
|
||||
use std::io::{BufRead, BufReader};
|
||||
use std::os::unix::process::{CommandExt, ExitStatusExt};
|
||||
use std::path::Path;
|
||||
use std::process::{Child, Command, ExitStatus, Stdio};
|
||||
|
@ -18,23 +18,134 @@ use std::sync::{mpsc, Arc};
|
|||
use std::thread;
|
||||
use std::thread::sleep;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{fs, io};
|
||||
use std::{
|
||||
fs, io,
|
||||
io::{Read, Write},
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tracing_subscriber::{filter::LevelFilter, EnvFilter};
|
||||
|
||||
mod env_runtime;
|
||||
|
||||
fn get_config(
|
||||
model_id: &str,
|
||||
revision: &Option<String>,
|
||||
) -> Result<Config, Box<dyn std::error::Error>> {
|
||||
let mut path = std::path::Path::new(model_id).to_path_buf();
|
||||
let model_id = model_id.to_string();
|
||||
let filename = if !path.exists() {
|
||||
// Assume it's a hub id
|
||||
|
||||
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
|
||||
// env variable has precedence over on file token.
|
||||
ApiBuilder::new().with_token(Some(token)).build()?
|
||||
} else {
|
||||
Api::new()?
|
||||
};
|
||||
let repo = if let Some(ref revision) = revision {
|
||||
api.repo(Repo::with_revision(
|
||||
model_id,
|
||||
RepoType::Model,
|
||||
revision.to_string(),
|
||||
))
|
||||
} else {
|
||||
api.model(model_id)
|
||||
};
|
||||
repo.get("config.json")?
|
||||
} else {
|
||||
path.push("config.json");
|
||||
path
|
||||
};
|
||||
|
||||
let content = std::fs::read_to_string(filename)?;
|
||||
let config: RawConfig = serde_json::from_str(&content)?;
|
||||
|
||||
let config: Config = config.into();
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
|
||||
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
|
||||
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
|
||||
if let Some(config) = config {
|
||||
if prefix_caching.is_none() {
|
||||
if config.vision_config.is_some() {
|
||||
tracing::info!("Disabling prefix caching because of VLM model");
|
||||
prefix_caching = Some("0".to_string());
|
||||
} else if config.is_encoder_decoder {
|
||||
tracing::info!("Disabling prefix caching because of seq2seq model");
|
||||
prefix_caching = Some("0".to_string());
|
||||
}
|
||||
}
|
||||
match config.head_dim {
|
||||
Some(h) if h == 64 || h == 128 || h == 256 => {
|
||||
if lora_adapters.is_some() && prefix_caching.is_none() {
|
||||
tracing::info!("Disabling prefix caching because of lora adapters");
|
||||
prefix_caching = Some("0".to_string());
|
||||
}
|
||||
match config.model_type.as_deref() {
|
||||
Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
|
||||
// Required because gemma2 needs bfloat16 which is not supported by
|
||||
// flashinfer ?
|
||||
if attention.is_none() {
|
||||
tracing::info!(
|
||||
"Forcing flash decoding because model {} requires it",
|
||||
config.model_type.as_ref().unwrap()
|
||||
);
|
||||
attention = Some("flashdecoding".to_string());
|
||||
}
|
||||
}
|
||||
Some("t5") => {}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if attention.is_none() {
|
||||
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
|
||||
attention = Some("flashdecoding".to_string());
|
||||
}
|
||||
if prefix_caching.is_none() {
|
||||
prefix_caching = Some("0".to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
|
||||
let attention = attention.unwrap_or("flashinfer".to_string());
|
||||
(prefix_caching, attention)
|
||||
}
|
||||
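The new resolve_attention above decides prefix caching and the attention backend in two steps: explicit USE_PREFIX_CACHING / ATTENTION environment variables always win, and only then does the model config pick defaults. A minimal sketch of that precedence, assuming a pared-down Config with only the fields the decision needs and leaving out the lora-adapter and per-model-type special cases (the function and struct here are illustrative, not the launcher's own):

    #[derive(Default)]
    struct Config {
        vision_config: Option<()>, // Some(_) marks a vision-language model
        is_encoder_decoder: bool,
        head_dim: Option<usize>,
    }

    fn resolve(config: &Config) -> (String, String) {
        // Explicit env vars win; otherwise fall back to config-driven defaults.
        let mut prefix_caching = std::env::var("USE_PREFIX_CACHING").ok();
        let mut attention = std::env::var("ATTENTION").ok();
        if prefix_caching.is_none() && (config.vision_config.is_some() || config.is_encoder_decoder) {
            prefix_caching = Some("0".to_string());
        }
        if attention.is_none() && !matches!(config.head_dim, Some(64 | 128 | 256)) {
            // flashinfer only supports a few head dims; anything else falls back to
            // flash decoding and turns prefix caching off.
            attention = Some("flashdecoding".to_string());
            prefix_caching.get_or_insert_with(|| "0".to_string());
        }
        (
            prefix_caching.unwrap_or_else(|| "true".to_string()),
            attention.unwrap_or_else(|| "flashinfer".to_string()),
        )
    }

    fn main() {
        // With no env overrides set, a VLM keeps flashinfer but turns prefix caching off.
        let vlm = Config { vision_config: Some(()), head_dim: Some(128), ..Default::default() };
        assert_eq!(resolve(&vlm), ("0".to_string(), "flashinfer".to_string()));
    }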
|
||||
#[derive(Deserialize)]
|
||||
struct RawConfig {
|
||||
max_position_embeddings: Option<usize>,
|
||||
n_positions: Option<usize>,
|
||||
model_type: Option<String>,
|
||||
max_seq_len: Option<usize>,
|
||||
quantization_config: Option<QuantizationConfig>,
|
||||
n_embd: Option<usize>,
|
||||
hidden_size: Option<usize>,
|
||||
num_attention_heads: Option<usize>,
|
||||
head_dim: Option<usize>,
|
||||
vision_config: Option<VisionConfig>,
|
||||
is_encoder_decoder: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct QuantizationConfig {
|
||||
quant_method: Option<Quantization>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct VisionConfig {}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Config {
|
||||
max_position_embeddings: Option<usize>,
|
||||
quantize: Option<Quantization>,
|
||||
head_dim: Option<usize>,
|
||||
model_type: Option<String>,
|
||||
vision_config: Option<VisionConfig>,
|
||||
is_encoder_decoder: bool,
|
||||
}
|
||||
|
||||
impl From<RawConfig> for Config {
|
||||
|
@ -43,13 +154,39 @@ impl From<RawConfig> for Config {
|
|||
.max_position_embeddings
|
||||
.or(other.max_seq_len)
|
||||
.or(other.n_positions);
|
||||
let quantize = other.quantization_config.and_then(|q| q.quant_method);
|
||||
let head_dim = other.head_dim.or_else(|| {
|
||||
match (other.hidden_size, other.n_embd, other.num_attention_heads) {
|
||||
(Some(hidden_size), _, Some(num_attention_heads))
|
||||
if hidden_size % num_attention_heads == 0 =>
|
||||
{
|
||||
Some(hidden_size / num_attention_heads)
|
||||
}
|
||||
// Legacy
|
||||
(_, Some(hidden_size), Some(num_attention_heads))
|
||||
if hidden_size % num_attention_heads == 0 =>
|
||||
{
|
||||
Some(hidden_size / num_attention_heads)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
});
|
||||
let model_type = other.model_type;
|
||||
let vision_config = other.vision_config;
|
||||
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
|
||||
Config {
|
||||
max_position_embeddings,
|
||||
quantize,
|
||||
head_dim,
|
||||
model_type,
|
||||
vision_config,
|
||||
is_encoder_decoder,
|
||||
}
|
||||
}
|
||||
}
|
||||
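The From<RawConfig> conversion above derives head_dim as hidden_size / num_attention_heads when the config does not state it, and only when the division is exact. A small worked example of that arithmetic (the helper name is made up for illustration; the real conversion also checks the legacy n_embd field):

    fn derive_head_dim(
        head_dim: Option<usize>,
        hidden_size: Option<usize>,
        num_attention_heads: Option<usize>,
    ) -> Option<usize> {
        head_dim.or_else(|| match (hidden_size, num_attention_heads) {
            // Only derive when the division is exact, mirroring the guards above.
            (Some(h), Some(n)) if n != 0 && h % n == 0 => Some(h / n),
            _ => None,
        })
    }

    fn main() {
        // Llama-style config: hidden_size 4096 with 32 attention heads -> 128-dim heads.
        assert_eq!(derive_head_dim(None, Some(4096), Some(32)), Some(128));
        // An explicit head_dim always wins.
        assert_eq!(derive_head_dim(Some(64), Some(4096), Some(32)), Some(64));
    }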
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
#[derive(Clone, Copy, Debug, ValueEnum, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
enum Quantization {
|
||||
/// 4 bit quantization. Requires a specific AWQ quantized model:
|
||||
/// <https://hf.co/models?search=awq>.
|
||||
|
@ -72,17 +209,17 @@ enum Quantization {
|
|||
Marlin,
|
||||
/// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,
|
||||
/// but it is known that the model will be much slower to run than the native f16.
|
||||
#[deprecated(
|
||||
since = "1.1.0",
|
||||
note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
|
||||
)]
|
||||
// #[deprecated(
|
||||
// since = "1.1.0",
|
||||
// note = "Use `eetq` instead, which provides better latencies overall and is drop-in in most cases"
|
||||
// )]
|
||||
Bitsandbytes,
|
||||
/// Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x,
|
||||
/// but it is known that the model will be much slower to run than the native f16.
|
||||
BitsandbytesNF4,
|
||||
BitsandbytesNf4,
|
||||
/// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
|
||||
/// perplexity performance for your model.
|
||||
BitsandbytesFP4,
|
||||
BitsandbytesFp4,
|
||||
/// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above
|
||||
/// This dtype has native ops and should be the fastest if available.
|
||||
/// This is currently not the fastest because of local unpacking + padding to satisfy matrix
|
||||
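Deriving Deserialize with rename_all = "kebab-case" on Quantization is what lets the launcher read quant_method directly out of a model's quantization_config, and it is presumably also why the variants were renamed to BitsandbytesNf4 / BitsandbytesFp4: serde's case splitting would otherwise turn NF4 into "n-f4". A hedged sketch of that round trip with a trimmed-down two-variant enum:

    use serde::Deserialize;

    #[derive(Clone, Copy, Debug, PartialEq, Deserialize)]
    #[serde(rename_all = "kebab-case")]
    enum Quantization {
        BitsandbytesNf4,
        Awq,
    }

    #[derive(Deserialize)]
    struct QuantizationConfig {
        quant_method: Option<Quantization>,
    }

    fn main() {
        // The PascalCase variant maps onto the kebab-case string found in config.json.
        let config: QuantizationConfig =
            serde_json::from_str(r#"{ "quant_method": "bitsandbytes-nf4" }"#).unwrap();
        assert_eq!(config.quant_method, Some(Quantization::BitsandbytesNf4));
    }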
|
@ -99,10 +236,10 @@ impl std::fmt::Display for Quantization {
|
|||
Quantization::Bitsandbytes => {
|
||||
write!(f, "bitsandbytes")
|
||||
}
|
||||
Quantization::BitsandbytesNF4 => {
|
||||
Quantization::BitsandbytesNf4 => {
|
||||
write!(f, "bitsandbytes-nf4")
|
||||
}
|
||||
Quantization::BitsandbytesFP4 => {
|
||||
Quantization::BitsandbytesFp4 => {
|
||||
write!(f, "bitsandbytes-fp4")
|
||||
}
|
||||
Quantization::Exl2 => {
|
||||
|
@ -721,6 +858,7 @@ fn shard_manager(
|
|||
.args(shard_args)
|
||||
.env_clear()
|
||||
.envs(envs)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.process_group(0)
|
||||
|
@ -742,12 +880,13 @@ fn shard_manager(
|
|||
};
|
||||
|
||||
// Wire up the child's stdio: keep stdin for forwarding, redirect stdout/stderr to the console
|
||||
let mut pstdin = p.stdin.take().unwrap();
|
||||
let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
|
||||
let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
|
||||
|
||||
//stdout tracing thread
|
||||
thread::spawn(move || {
|
||||
log_lines(shard_stdout_reader.lines());
|
||||
log_lines(shard_stdout_reader);
|
||||
});
|
||||
// We read stderr in another thread as it seems that lines() can block in some cases
|
||||
let (err_sender, err_receiver) = mpsc::channel();
|
||||
|
@ -756,6 +895,18 @@ fn shard_manager(
|
|||
err_sender.send(line).unwrap_or(());
|
||||
}
|
||||
});
|
||||
// We read stdin in another thread as it seems that lines() can block in some cases
|
||||
thread::spawn(move || {
|
||||
let mut stdin = io::stdin(); // We get `Stdin` here.
|
||||
loop {
|
||||
let mut buffer = vec![0; 4096];
|
||||
if let Ok(n) = stdin.read(&mut buffer) {
|
||||
if n > 0 {
|
||||
let _ = pstdin.write_all(&buffer[..n]);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let mut ready = false;
|
||||
let start_time = Instant::now();
|
||||
|
@ -862,19 +1013,36 @@ impl PythonLogMessage {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&String> for PythonLogMessage {
|
||||
impl TryFrom<&[u8]> for PythonLogMessage {
|
||||
type Error = serde_json::Error;
|
||||
|
||||
fn try_from(value: &String) -> Result<Self, Self::Error> {
|
||||
serde_json::from_str::<Self>(value)
|
||||
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
|
||||
serde_json::from_slice::<Self>(value)
|
||||
}
|
||||
}
|
||||
|
||||
fn log_lines<S: Sized + BufRead>(lines: Lines<S>) {
|
||||
for line in lines.map_while(Result::ok) {
|
||||
match PythonLogMessage::try_from(&line) {
|
||||
Ok(log) => log.trace(),
|
||||
Err(_) => tracing::debug!("{line}"),
|
||||
fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
|
||||
let mut buffer = vec![0u8; 8 * 4096];
|
||||
let mut stdout = std::io::stdout();
|
||||
loop {
|
||||
let n = bufread.read(&mut buffer);
|
||||
if let Ok(n) = n {
|
||||
if n > 0 {
|
||||
let mut lines = buffer[..n].split(|i| *i == b'\n').peekable();
|
||||
while let Some(line) = lines.next() {
|
||||
match PythonLogMessage::try_from(line) {
|
||||
Ok(log) => log.trace(),
|
||||
// Not a structured log record: forward the raw line to stdout (useful for interactive debugging).
|
||||
Err(_) => {
|
||||
stdout.write_all(line).unwrap();
|
||||
if lines.peek().is_some() {
|
||||
stdout.write_all(b"\n").unwrap();
|
||||
}
|
||||
stdout.flush().unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
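log_lines now reads raw bytes from the shard's stdout, splits them on newlines itself, and tries to parse each line as a structured Python log record, echoing anything that is not JSON straight through. A minimal sketch of that per-line fallback, assuming a stripped-down record with only text and log_level fields (the real message carries more):

    use serde::Deserialize;

    #[derive(Deserialize, Debug)]
    struct LogRecord {
        text: String,
        log_level: String,
    }

    fn handle_chunk(chunk: &[u8]) {
        // Split the raw buffer on b'\n' like the launcher does, then try JSON first
        // and fall back to echoing the raw line.
        for line in chunk.split(|b| *b == b'\n').filter(|l| !l.is_empty()) {
            match serde_json::from_slice::<LogRecord>(line) {
                Ok(record) => println!("[{}] {}", record.log_level, record.text),
                Err(_) => println!("{}", String::from_utf8_lossy(line)),
            }
        }
    }

    fn main() {
        handle_chunk(b"{\"text\":\"shard ready\",\"log_level\":\"INFO\"}\n(pdb) plain debugger output\n");
    }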
|
@ -1034,7 +1202,7 @@ fn download_convert_model(
|
|||
let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
|
||||
|
||||
thread::spawn(move || {
|
||||
log_lines(download_stdout.lines());
|
||||
log_lines(download_stdout);
|
||||
});
|
||||
|
||||
let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
|
||||
|
@ -1085,6 +1253,7 @@ fn spawn_shards(
|
|||
cuda_graphs: Vec<usize>,
|
||||
max_total_tokens: usize,
|
||||
max_input_tokens: usize,
|
||||
quantize: Option<Quantization>,
|
||||
max_log_level: LevelFilter,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
shutdown_receiver: &mpsc::Receiver<()>,
|
||||
|
@ -1106,7 +1275,6 @@ fn spawn_shards(
|
|||
let shutdown_sender = shutdown_sender.clone();
|
||||
let otlp_endpoint = args.otlp_endpoint.clone();
|
||||
let otlp_service_name = args.otlp_service_name.clone();
|
||||
let quantize = args.quantize;
|
||||
let speculate = args.speculate;
|
||||
let dtype = args.dtype;
|
||||
let trust_remote_code = args.trust_remote_code;
|
||||
|
@ -1429,45 +1597,12 @@ fn main() -> Result<(), LauncherError> {
|
|||
|
||||
tracing::info!("{:#?}", args);
|
||||
|
||||
let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
|
||||
let model_id = args.model_id.clone();
|
||||
let mut path = std::path::Path::new(&args.model_id).to_path_buf();
|
||||
let filename = if !path.exists() {
|
||||
// Assume it's a hub id
|
||||
|
||||
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
|
||||
// env variable has precedence over on file token.
|
||||
ApiBuilder::new().with_token(Some(token)).build()?
|
||||
} else {
|
||||
Api::new()?
|
||||
};
|
||||
let repo = if let Some(ref revision) = args.revision {
|
||||
api.repo(Repo::with_revision(
|
||||
model_id,
|
||||
RepoType::Model,
|
||||
revision.to_string(),
|
||||
))
|
||||
} else {
|
||||
api.model(model_id)
|
||||
};
|
||||
repo.get("config.json")?
|
||||
} else {
|
||||
path.push("config.json");
|
||||
path
|
||||
};
|
||||
|
||||
let content = std::fs::read_to_string(filename)?;
|
||||
let config: RawConfig = serde_json::from_str(&content)?;
|
||||
|
||||
if config.model_type == Some("gemma2".to_string()) {
|
||||
tracing::info!("Forcing flash decoding because of softcap usage");
|
||||
std::env::set_var("FLASH_DECODING", "1");
|
||||
}
|
||||
let config: Config = config.into();
|
||||
|
||||
// Quantization usually means you're even more RAM constrained.
|
||||
let max_default = 4096;
|
||||
let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
|
||||
let quantize = config.as_ref().and_then(|c| c.quantize);
|
||||
// Quantization usually means you're even more RAM constrained.
|
||||
let max_default = 4096;
|
||||
|
||||
let max_position_embeddings = if let Some(config) = &config {
|
||||
if let Some(max_position_embeddings) = config.max_position_embeddings {
|
||||
if max_position_embeddings > max_default {
|
||||
let max = max_position_embeddings;
|
||||
|
@ -1477,17 +1612,20 @@ fn main() -> Result<(), LauncherError> {
|
|||
{
|
||||
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
|
||||
}
|
||||
Ok(max_default)
|
||||
max_default
|
||||
} else {
|
||||
Ok(max_position_embeddings)
|
||||
max_position_embeddings
|
||||
}
|
||||
} else {
|
||||
Err(Box::new(LauncherError::ArgumentValidation(
|
||||
"no max defined".to_string(),
|
||||
)))
|
||||
max_default
|
||||
}
|
||||
} else {
|
||||
max_default
|
||||
};
|
||||
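With get_max_position_embeddings gone, a missing or oversized max_position_embeddings no longer errors: the launcher simply clamps its default to 4096 tokens unless the user raises the limits on the command line. A compact restatement of that rule (the helper name is illustrative only):

    fn default_max_total_tokens(config_max_position_embeddings: Option<usize>) -> usize {
        let max_default = 4096;
        match config_max_position_embeddings {
            // Large-context models are clamped to the default unless
            // --max-total-tokens / --max-input-tokens are passed explicitly.
            Some(m) if m > max_default => max_default,
            Some(m) => m,
            None => max_default,
        }
    }

    fn main() {
        assert_eq!(default_max_total_tokens(Some(131_072)), 4096);
        assert_eq!(default_max_total_tokens(Some(2048)), 2048);
        assert_eq!(default_max_total_tokens(None), 4096);
    }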
let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
|
||||
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
|
||||
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
|
||||
std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
|
||||
std::env::set_var("ATTENTION", attention);
|
||||
|
||||
let max_input_tokens = {
|
||||
match (args.max_input_tokens, args.max_input_length) {
|
||||
|
@ -1544,18 +1682,26 @@ fn main() -> Result<(), LauncherError> {
|
|||
)));
|
||||
}
|
||||
|
||||
let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
|
||||
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
|
||||
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
|
||||
}
|
||||
let quantize = args.quantize.or(quantize);
|
||||
let cuda_graphs = match (&args.cuda_graphs, &quantize) {
|
||||
(Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),
|
||||
#[allow(deprecated)]
|
||||
(
|
||||
None,
|
||||
Some(
|
||||
Quantization::Bitsandbytes
|
||||
| Quantization::BitsandbytesNF4
|
||||
| Quantization::BitsandbytesFP4,
|
||||
| Quantization::BitsandbytesNf4
|
||||
| Quantization::BitsandbytesFp4,
|
||||
),
|
||||
) => {
|
||||
tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
|
||||
tracing::warn!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
|
||||
vec![]
|
||||
}
|
||||
(None, Some(Quantization::Exl2)) => {
|
||||
tracing::warn!("Exl2 doesn't work with cuda graphs, deactivating them");
|
||||
vec![]
|
||||
}
|
||||
_ => {
|
||||
|
@ -1672,6 +1818,7 @@ fn main() -> Result<(), LauncherError> {
|
|||
cuda_graphs,
|
||||
max_total_tokens,
|
||||
max_input_tokens,
|
||||
quantize,
|
||||
max_log_level,
|
||||
shutdown.clone(),
|
||||
&shutdown_receiver,
|
||||
|
|
|
@ -33,13 +33,13 @@ export function get_options() {
|
|||
// rate: 20,
|
||||
// timeUnit: '1s',
|
||||
// },
|
||||
load_test: {
|
||||
executor: 'constant-arrival-rate',
|
||||
duration: '60s',
|
||||
preAllocatedVUs: 100,
|
||||
rate: 1,
|
||||
timeUnit: '1s',
|
||||
},
|
||||
// load_test: {
|
||||
// executor: 'constant-arrival-rate',
|
||||
// duration: '60s',
|
||||
// preAllocatedVUs: 100,
|
||||
// rate: 1,
|
||||
// timeUnit: '1s',
|
||||
// },
|
||||
// breakpoint: {
|
||||
// executor: 'ramping-arrival-rate', //Assure load increase if the system slows
|
||||
// preAllocatedVUs: 300,
|
||||
|
@ -47,12 +47,12 @@ export function get_options() {
|
|||
// { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load
|
||||
// ],
|
||||
// },
|
||||
// throughput: {
|
||||
// executor: 'shared-iterations',
|
||||
// vus: 100,
|
||||
// iterations: 200,
|
||||
// maxDuration: '40s',
|
||||
// },
|
||||
throughput: {
|
||||
executor: 'shared-iterations',
|
||||
vus: 100,
|
||||
iterations: 200,
|
||||
maxDuration: '40s',
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
{ pkgs, nix-filter }:
|
||||
|
||||
let
|
||||
filter = nix-filter.lib;
|
||||
in
|
||||
with pkgs;
|
||||
defaultCrateOverrides
|
||||
// {
|
||||
aws-lc-rs = attrs: {
|
||||
# aws-lc-rs does its own custom parsing of Cargo environment
|
||||
# variables like DEP_.*_INCLUDE. However buildRustCrate does
|
||||
# not use the version number, so the parsing fails.
|
||||
postPatch = ''
|
||||
substituteInPlace build.rs \
|
||||
--replace-fail \
|
||||
"assert!(!selected.is_empty()" \
|
||||
"// assert!(!selected.is_empty()"
|
||||
'';
|
||||
};
|
||||
rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = "-C target-feature=-crt-static"; };
|
||||
|
||||
grpc-metadata = attrs: {
|
||||
src = filter {
|
||||
root = ../backends/grpc-metadata;
|
||||
include = with filter; [
|
||||
isDirectory
|
||||
(matchExt "rs")
|
||||
];
|
||||
};
|
||||
};
|
||||
text-generation-benchmark = attrs: {
|
||||
src = filter {
|
||||
root = ../benchmark;
|
||||
include = with filter; [
|
||||
isDirectory
|
||||
(matchExt "rs")
|
||||
];
|
||||
};
|
||||
};
|
||||
text-generation-client = attrs: {
|
||||
src = filter {
|
||||
root = ../.;
|
||||
include = with filter; [
|
||||
isDirectory
|
||||
(and (inDirectory "backends/client") (matchExt "rs"))
|
||||
(and (inDirectory "proto") (matchExt "proto"))
|
||||
];
|
||||
};
|
||||
postPatch = "cd backends/client";
|
||||
buildInputs = [ protobuf ];
|
||||
};
|
||||
text-generation-launcher = attrs: {
|
||||
src = filter {
|
||||
root = ../launcher;
|
||||
include = with filter; [
|
||||
isDirectory
|
||||
(matchExt "rs")
|
||||
];
|
||||
};
|
||||
};
|
||||
text-generation-router = attrs: {
|
||||
src = filter {
|
||||
root = ../router;
|
||||
include = with filter; [
|
||||
isDirectory
|
||||
(matchExt "rs")
|
||||
];
|
||||
};
|
||||
};
|
||||
text-generation-router-v3 = attrs: {
|
||||
# We need to do the src/source root dance so that the build
|
||||
# has access to the protobuf file.
|
||||
src = filter {
|
||||
root = ../.;
|
||||
include = with filter; [
|
||||
isDirectory
|
||||
(and (inDirectory "backends/v3") (matchExt "rs"))
|
||||
(and (inDirectory "proto") (matchExt "proto"))
|
||||
];
|
||||
};
|
||||
postPatch = "cd backends/v3";
|
||||
buildInputs = [ protobuf ];
|
||||
};
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
{
|
||||
nix-filter,
|
||||
buildPythonPackage,
|
||||
poetry-core,
|
||||
mypy-protobuf,
|
||||
awq-inference-engine,
|
||||
causal-conv1d,
|
||||
eetq,
|
||||
einops,
|
||||
exllamav2,
|
||||
fbgemm-gpu,
|
||||
flashinfer,
|
||||
flash-attn,
|
||||
flash-attn-layer-norm,
|
||||
flash-attn-rotary,
|
||||
grpc-interceptor,
|
||||
grpcio-reflection,
|
||||
grpcio-status,
|
||||
grpcio-tools,
|
||||
hf-transfer,
|
||||
loguru,
|
||||
mamba-ssm,
|
||||
marlin-kernels,
|
||||
opentelemetry-api,
|
||||
opentelemetry-exporter-otlp,
|
||||
opentelemetry-instrumentation-grpc,
|
||||
opentelemetry-semantic-conventions,
|
||||
peft,
|
||||
punica-kernels,
|
||||
safetensors,
|
||||
tokenizers,
|
||||
torch,
|
||||
sentencepiece,
|
||||
transformers,
|
||||
typer,
|
||||
vllm,
|
||||
}:
|
||||
|
||||
let
|
||||
filter = nix-filter.lib;
|
||||
in
|
||||
buildPythonPackage {
|
||||
name = "text-generation-server";
|
||||
|
||||
src = filter {
|
||||
root = ../.;
|
||||
include = with filter; [
|
||||
isDirectory
|
||||
(and (inDirectory "server") (or_ (matchExt "py") (matchExt "pyi")))
|
||||
"server/pyproject.toml"
|
||||
(and (inDirectory "proto/v3") (matchExt "proto"))
|
||||
];
|
||||
};
|
||||
|
||||
pyproject = true;
|
||||
|
||||
build-system = [ poetry-core ];
|
||||
|
||||
nativeBuildInputs = [ mypy-protobuf ];
|
||||
|
||||
pythonRelaxDeps = [
|
||||
"einops"
|
||||
"huggingface-hub"
|
||||
"loguru"
|
||||
"opentelemetry-instrumentation-grpc"
|
||||
"sentencepiece"
|
||||
"typer"
|
||||
];
|
||||
|
||||
pythonRemoveDeps = [ "scipy" ];
|
||||
|
||||
dependencies = [
|
||||
awq-inference-engine
|
||||
eetq
|
||||
causal-conv1d
|
||||
einops
|
||||
exllamav2
|
||||
fbgemm-gpu
|
||||
flashinfer
|
||||
flash-attn
|
||||
flash-attn-layer-norm
|
||||
flash-attn-rotary
|
||||
grpc-interceptor
|
||||
grpcio-reflection
|
||||
grpcio-status
|
||||
grpcio-tools
|
||||
hf-transfer
|
||||
loguru
|
||||
mamba-ssm
|
||||
marlin-kernels
|
||||
opentelemetry-api
|
||||
opentelemetry-exporter-otlp
|
||||
opentelemetry-instrumentation-grpc
|
||||
opentelemetry-semantic-conventions
|
||||
peft
|
||||
punica-kernels
|
||||
safetensors
|
||||
sentencepiece
|
||||
tokenizers
|
||||
transformers
|
||||
typer
|
||||
vllm
|
||||
];
|
||||
|
||||
prePatch = ''
|
||||
python -m grpc_tools.protoc -Iproto/v3 --python_out=server/text_generation_server/pb \
|
||||
--grpc_python_out=server/text_generation_server/pb --mypy_out=server/text_generation_server/pb proto/v3/generate.proto
|
||||
find server/text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
|
||||
touch server/text_generation_server/pb/__init__.py
|
||||
cd server
|
||||
'';
|
||||
}
|
|
@ -3,22 +3,23 @@ syntax = "proto3";
|
|||
package generate.v3;
|
||||
|
||||
service TextGenerationService {
|
||||
/// Model Info
|
||||
rpc Info (InfoRequest) returns (InfoResponse) {}
|
||||
/// Service discovery
|
||||
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
|
||||
/// Empties batch cache
|
||||
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
|
||||
/// Remove requests from a cached batch
|
||||
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
|
||||
/// Warmup the model and compute max cache size
|
||||
rpc Warmup (WarmupRequest) returns (WarmupResponse);
|
||||
/// Prefill batch and decode first token
|
||||
rpc Prefill (PrefillRequest) returns (PrefillResponse);
|
||||
/// Decode token for a list of prefilled batches
|
||||
rpc Decode (DecodeRequest) returns (DecodeResponse);
|
||||
/// Health check
|
||||
rpc Health (HealthRequest) returns (HealthResponse);
|
||||
/// Model Info
|
||||
rpc Info(InfoRequest) returns (InfoResponse) {}
|
||||
/// Service discovery
|
||||
rpc ServiceDiscovery(ServiceDiscoveryRequest)
|
||||
returns (ServiceDiscoveryResponse) {}
|
||||
/// Empties batch cache
|
||||
rpc ClearCache(ClearCacheRequest) returns (ClearCacheResponse);
|
||||
/// Remove requests from a cached batch
|
||||
rpc FilterBatch(FilterBatchRequest) returns (FilterBatchResponse);
|
||||
/// Warmup the model and compute max cache size
|
||||
rpc Warmup(WarmupRequest) returns (WarmupResponse);
|
||||
/// Prefill batch and decode first token
|
||||
rpc Prefill(PrefillRequest) returns (PrefillResponse);
|
||||
/// Decode token for a list of prefilled batches
|
||||
rpc Decode(DecodeRequest) returns (DecodeResponse);
|
||||
/// Health check
|
||||
rpc Health(HealthRequest) returns (HealthResponse);
|
||||
}
|
||||
|
||||
message HealthRequest {}
|
||||
|
@ -28,240 +29,241 @@ message HealthResponse {}
|
|||
message InfoRequest {}
|
||||
|
||||
message InfoResponse {
|
||||
bool requires_padding = 1;
|
||||
string dtype = 2;
|
||||
string device_type = 3;
|
||||
optional uint32 window_size = 4;
|
||||
uint32 speculate = 5;
|
||||
bool requires_padding = 1;
|
||||
string dtype = 2;
|
||||
string device_type = 3;
|
||||
optional uint32 window_size = 4;
|
||||
uint32 speculate = 5;
|
||||
}
|
||||
|
||||
/// Empty request
|
||||
message ServiceDiscoveryRequest {}
|
||||
|
||||
message ServiceDiscoveryResponse {
|
||||
/// Other shards urls
|
||||
repeated string urls = 1;
|
||||
/// Other shards urls
|
||||
repeated string urls = 1;
|
||||
}
|
||||
|
||||
message ClearCacheRequest {
|
||||
/// Optional batch id
|
||||
optional uint64 id = 1;
|
||||
/// Optional batch id
|
||||
optional uint64 id = 1;
|
||||
}
|
||||
|
||||
/// Empty response
|
||||
message ClearCacheResponse {}
|
||||
|
||||
message Image {
|
||||
/// Binary image data.
|
||||
bytes data = 1;
|
||||
/// Binary image data.
|
||||
bytes data = 1;
|
||||
|
||||
/// Image MIME type.
|
||||
string mimetype = 2;
|
||||
/// Image MIME type.
|
||||
string mimetype = 2;
|
||||
}
|
||||
|
||||
message InputChunk {
|
||||
oneof chunk {
|
||||
/// Plain text data
|
||||
string text = 1;
|
||||
/// Image data
|
||||
Image image = 2;
|
||||
}
|
||||
oneof chunk {
|
||||
/// Plain text data
|
||||
string text = 1;
|
||||
/// Image data
|
||||
Image image = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message Input {
|
||||
repeated InputChunk chunks = 1;
|
||||
}
|
||||
message Input { repeated InputChunk chunks = 1; }
|
||||
|
||||
enum GrammarType {
|
||||
GRAMMAR_TYPE_NONE = 0;
|
||||
GRAMMAR_TYPE_JSON = 1;
|
||||
GRAMMAR_TYPE_REGEX = 2;
|
||||
GRAMMAR_TYPE_NONE = 0;
|
||||
GRAMMAR_TYPE_JSON = 1;
|
||||
GRAMMAR_TYPE_REGEX = 2;
|
||||
}
|
||||
|
||||
message NextTokenChooserParameters {
|
||||
/// exponential scaling output probability distribution
|
||||
float temperature = 1;
|
||||
/// restricting to the k highest probability elements
|
||||
uint32 top_k = 2;
|
||||
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
||||
float top_p = 3;
|
||||
/// restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
||||
float typical_p = 4;
|
||||
/// apply sampling on the logits
|
||||
bool do_sample = 5;
|
||||
/// random seed for sampling
|
||||
uint64 seed = 6;
|
||||
/// repetition penalty
|
||||
float repetition_penalty = 7;
|
||||
/// frequency penalty
|
||||
float frequency_penalty = 9;
|
||||
/// token watermarking using "A Watermark for Large Language Models"
|
||||
bool watermark = 8;
|
||||
/// grammar (applied if not empty)
|
||||
string grammar = 10;
|
||||
/// grammar type
|
||||
GrammarType grammar_type = 11;
|
||||
/// exponential scaling output probability distribution
|
||||
float temperature = 1;
|
||||
/// restricting to the k highest probability elements
|
||||
uint32 top_k = 2;
|
||||
/// nucleus sampling: restrict to the smallest set of tokens whose cumulative probability reaches top_p
|
||||
float top_p = 3;
|
||||
/// typical sampling: restrict to tokens whose cumulative typicality mass reaches typical_p
|
||||
float typical_p = 4;
|
||||
/// apply sampling on the logits
|
||||
bool do_sample = 5;
|
||||
/// random seed for sampling
|
||||
uint64 seed = 6;
|
||||
/// repetition penalty
|
||||
float repetition_penalty = 7;
|
||||
/// frequency penalty
|
||||
float frequency_penalty = 9;
|
||||
/// token watermarking using "A Watermark for Large Language Models"
|
||||
bool watermark = 8;
|
||||
/// grammar (applied if not empty)
|
||||
string grammar = 10;
|
||||
/// grammar type
|
||||
GrammarType grammar_type = 11;
|
||||
}
|
||||
|
||||
message StoppingCriteriaParameters {
|
||||
/// Maximum number of generated tokens
|
||||
uint32 max_new_tokens = 1;
|
||||
/// Optional stopping sequences
|
||||
repeated string stop_sequences = 2;
|
||||
/// Ignore end of sequence token
|
||||
/// used for benchmarking
|
||||
bool ignore_eos_token = 3;
|
||||
/// Maximum number of generated tokens
|
||||
uint32 max_new_tokens = 1;
|
||||
/// Optional stopping sequences
|
||||
repeated string stop_sequences = 2;
|
||||
/// Ignore end of sequence token
|
||||
/// used for benchmarking
|
||||
bool ignore_eos_token = 3;
|
||||
}
|
||||
|
||||
message Request {
|
||||
/// Request ID
|
||||
uint64 id = 1;
|
||||
/// The generation context as chunks
|
||||
Input input_chunks = 8;
|
||||
/// The generation context, stringified input_chunks
|
||||
string inputs = 2;
|
||||
/// Context truncation
|
||||
uint32 truncate = 3;
|
||||
/// Next Token Chooser Parameters
|
||||
NextTokenChooserParameters parameters = 4;
|
||||
/// Stopping Criteria Parameters
|
||||
StoppingCriteriaParameters stopping_parameters = 5;
|
||||
/// Return prefill logprobs
|
||||
bool prefill_logprobs = 6;
|
||||
/// Return most likely n tokens
|
||||
uint32 top_n_tokens = 7;
|
||||
/// Paged attention blocks
|
||||
repeated uint32 blocks = 9;
|
||||
/// Paged attention slots
|
||||
repeated uint32 slots = 10;
|
||||
/// LORA adapter index
|
||||
optional string adapter_id = 11;
|
||||
/// Request ID
|
||||
uint64 id = 1;
|
||||
/// The generation context as chunks
|
||||
Input input_chunks = 8;
|
||||
/// The generation context, stringified input_chunks
|
||||
string inputs = 2;
|
||||
/// Context truncation
|
||||
uint32 truncate = 3;
|
||||
/// Next Token Chooser Parameters
|
||||
NextTokenChooserParameters parameters = 4;
|
||||
/// Stopping Criteria Parameters
|
||||
StoppingCriteriaParameters stopping_parameters = 5;
|
||||
/// Return prefill logprobs
|
||||
bool prefill_logprobs = 6;
|
||||
/// Return most likely n tokens
|
||||
uint32 top_n_tokens = 7;
|
||||
/// Paged attention blocks
|
||||
repeated uint32 blocks = 9;
|
||||
/// Paged attention slots
|
||||
repeated uint32 slots = 10;
|
||||
/// LORA adapter index
|
||||
optional string adapter_id = 11;
|
||||
/// Prefix length that can be retrieved from the KV cache.
|
||||
uint32 prefix_len = 12;
|
||||
/// Whether special tokens should be added during tokenization
|
||||
bool add_special_tokens = 13;
|
||||
}
|
||||
|
||||
message Batch {
|
||||
/// Batch ID
|
||||
uint64 id = 1;
|
||||
/// Individual requests
|
||||
repeated Request requests = 2;
|
||||
/// Batch size (==len(requests))
|
||||
uint32 size = 3;
|
||||
/// Maximum number of tokens this batch will grow to
|
||||
uint32 max_tokens = 4;
|
||||
/// Maximum number of Paged Attention blocks
|
||||
uint32 max_blocks = 5;
|
||||
/// Batch ID
|
||||
uint64 id = 1;
|
||||
/// Individual requests
|
||||
repeated Request requests = 2;
|
||||
/// Batch size (==len(requests))
|
||||
uint32 size = 3;
|
||||
/// Maximum number of tokens this batch will grow to
|
||||
uint32 max_tokens = 4;
|
||||
/// Maximum number of Paged Attention blocks
|
||||
uint32 max_blocks = 5;
|
||||
}
|
||||
|
||||
message CachedBatch {
|
||||
/// Batch ID
|
||||
uint64 id = 1;
|
||||
/// Individual requests ids
|
||||
repeated uint64 request_ids = 2;
|
||||
/// Batch size (==len(requests))
|
||||
uint32 size = 3;
|
||||
/// Maximum number of tokens this batch will grow to
|
||||
uint32 max_tokens = 4;
|
||||
/// Batch ID
|
||||
uint64 id = 1;
|
||||
/// Individual requests ids
|
||||
repeated uint64 request_ids = 2;
|
||||
/// Batch size (==len(requests))
|
||||
uint32 size = 3;
|
||||
/// Maximum number of tokens this batch will grow to
|
||||
uint32 max_tokens = 4;
|
||||
}
|
||||
|
||||
enum FinishReason {
|
||||
FINISH_REASON_LENGTH = 0;
|
||||
FINISH_REASON_EOS_TOKEN = 1;
|
||||
FINISH_REASON_STOP_SEQUENCE = 2;
|
||||
FINISH_REASON_LENGTH = 0;
|
||||
FINISH_REASON_EOS_TOKEN = 1;
|
||||
FINISH_REASON_STOP_SEQUENCE = 2;
|
||||
}
|
||||
|
||||
message GeneratedText {
|
||||
/// Output
|
||||
string text = 1;
|
||||
/// Number of generated tokens
|
||||
uint32 generated_tokens = 2;
|
||||
/// Finish reason
|
||||
FinishReason finish_reason = 3;
|
||||
/// Seed
|
||||
optional uint64 seed = 4;
|
||||
/// Output
|
||||
string text = 1;
|
||||
/// Number of generated tokens
|
||||
uint32 generated_tokens = 2;
|
||||
/// Finish reason
|
||||
FinishReason finish_reason = 3;
|
||||
/// Seed
|
||||
optional uint64 seed = 4;
|
||||
}
|
||||
|
||||
message Tokens {
|
||||
/// Token IDs
|
||||
repeated uint32 ids = 1;
|
||||
/// Logprobs
|
||||
repeated float logprobs = 2;
|
||||
/// tokens
|
||||
repeated string texts = 3;
|
||||
/// special
|
||||
repeated bool is_special = 4;
|
||||
/// Token IDs
|
||||
repeated uint32 ids = 1;
|
||||
/// Logprobs
|
||||
repeated float logprobs = 2;
|
||||
/// tokens
|
||||
repeated string texts = 3;
|
||||
/// special
|
||||
repeated bool is_special = 4;
|
||||
}
|
||||
|
||||
message Generation {
|
||||
/// Request ID
|
||||
uint64 request_id = 1;
|
||||
/// Prefill tokens (optional)
|
||||
Tokens prefill_tokens = 2;
|
||||
Tokens tokens = 3;
|
||||
/// Complete generated text
|
||||
optional GeneratedText generated_text = 4;
|
||||
/// Top tokens
|
||||
repeated Tokens top_tokens = 5;
|
||||
/// Request ID
|
||||
uint64 request_id = 1;
|
||||
/// Prefill tokens (optional)
|
||||
Tokens prefill_tokens = 2;
|
||||
Tokens tokens = 3;
|
||||
/// Complete generated text
|
||||
optional GeneratedText generated_text = 4;
|
||||
/// Top tokens
|
||||
repeated Tokens top_tokens = 5;
|
||||
}
|
||||
|
||||
message FilterBatchRequest {
|
||||
/// Batch ID
|
||||
uint64 batch_id = 1;
|
||||
/// Requests to keep
|
||||
repeated uint64 request_ids = 2;
|
||||
/// Batch ID
|
||||
uint64 batch_id = 1;
|
||||
/// Requests to keep
|
||||
repeated uint64 request_ids = 2;
|
||||
}
|
||||
|
||||
message FilterBatchResponse {
|
||||
/// Filtered Batch (cached)
|
||||
CachedBatch batch = 1;
|
||||
/// Filtered Batch (cached)
|
||||
CachedBatch batch = 1;
|
||||
}
|
||||
|
||||
|
||||
message PrefillRequest {
|
||||
/// Batch
|
||||
Batch batch = 1;
|
||||
/// Batch
|
||||
Batch batch = 1;
|
||||
}
|
||||
|
||||
message PrefillResponse {
|
||||
/// Generation
|
||||
repeated Generation generations = 1;
|
||||
/// Next batch (cached)
|
||||
optional CachedBatch batch = 2;
|
||||
/// Forward elapsed time in nanoseconds
|
||||
uint64 forward_ns = 3;
|
||||
/// Decode elapsed time in nanoseconds
|
||||
uint64 decode_ns = 4;
|
||||
/// Total elapsed time in nanoseconds
|
||||
uint64 total_ns = 5;
|
||||
/// Generation
|
||||
repeated Generation generations = 1;
|
||||
/// Next batch (cached)
|
||||
optional CachedBatch batch = 2;
|
||||
/// Forward elapsed time in nanoseconds
|
||||
uint64 forward_ns = 3;
|
||||
/// Decode elapsed time in nanoseconds
|
||||
uint64 decode_ns = 4;
|
||||
/// Total elapsed time in nanoseconds
|
||||
uint64 total_ns = 5;
|
||||
}
|
||||
|
||||
message DecodeRequest {
|
||||
/// Cached batches
|
||||
repeated CachedBatch batches = 1;
|
||||
/// Cached batches
|
||||
repeated CachedBatch batches = 1;
|
||||
}
|
||||
|
||||
message DecodeResponse {
|
||||
/// Decodes
|
||||
repeated Generation generations = 1;
|
||||
/// Next batch (cached)
|
||||
optional CachedBatch batch = 2;
|
||||
/// Forward elapsed time in nanoseconds
|
||||
uint64 forward_ns = 3;
|
||||
/// Decode elapsed time in nanoseconds
|
||||
uint64 decode_ns = 4;
|
||||
/// Total elapsed time in nanoseconds
|
||||
uint64 total_ns = 5;
|
||||
/// Concatenate elapsed time in nanoseconds
|
||||
optional uint64 concat_ns = 6;
|
||||
/// Decodes
|
||||
repeated Generation generations = 1;
|
||||
/// Next batch (cached)
|
||||
optional CachedBatch batch = 2;
|
||||
/// Forward elapsed time in nanoseconds
|
||||
uint64 forward_ns = 3;
|
||||
/// Decode elapsed time in nanoseconds
|
||||
uint64 decode_ns = 4;
|
||||
/// Total elapsed time in nanoseconds
|
||||
uint64 total_ns = 5;
|
||||
/// Concatenate elapsed time in nanoseconds
|
||||
optional uint64 concat_ns = 6;
|
||||
}
|
||||
|
||||
message WarmupRequest {
|
||||
/// Batch to warmup on
|
||||
Batch batch = 1;
|
||||
uint32 max_input_length = 2;
|
||||
uint32 max_prefill_tokens = 3;
|
||||
uint32 max_total_tokens = 4;
|
||||
/// Batch to warmup on
|
||||
Batch batch = 1;
|
||||
uint32 max_input_length = 2;
|
||||
uint32 max_prefill_tokens = 3;
|
||||
uint32 max_total_tokens = 4;
|
||||
}
|
||||
|
||||
message WarmupResponse {
|
||||
/// Maximum number of tokens supported by the model
|
||||
optional uint32 max_supported_total_tokens = 1;
|
||||
/// Maximum number of tokens supported by the model
|
||||
optional uint32 max_supported_total_tokens = 1;
|
||||
}
|
||||
|
|
|
@ -27,8 +27,14 @@ reqwest = { version = "0.11.20", features = [] }
|
|||
serde = "1.0.188"
|
||||
serde_json = "1.0.107"
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { workspace = true}
|
||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tokenizers = { workspace = true }
|
||||
tokio = { version = "1.32.0", features = [
|
||||
"rt",
|
||||
"rt-multi-thread",
|
||||
"parking_lot",
|
||||
"signal",
|
||||
"sync",
|
||||
] }
|
||||
tokio-stream = "0.1.14"
|
||||
tower-http = { version = "0.5.1", features = ["cors"] }
|
||||
tracing = "0.1.40"
|
||||
|
@ -37,16 +43,22 @@ tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
|
|||
utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
||||
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
||||
minijinja = { version = "2.0.2" }
|
||||
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
||||
"opentelemetry-otlp",
|
||||
] }
|
||||
minijinja = { workspace = true }
|
||||
minijinja-contrib = { workspace = true }
|
||||
futures-util = "0.3.30"
|
||||
regex = "1.10.3"
|
||||
once_cell = "1.19.0"
|
||||
image = "0.25.1"
|
||||
base64 = { workspace = true }
|
||||
sysinfo = "0.30.13"
|
||||
uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] }
|
||||
uuid = { version = "1.9.1", default-features = false, features = [
|
||||
"v4",
|
||||
"fast-rng",
|
||||
"macro-diagnostics",
|
||||
] }
|
||||
csv = "1.3.0"
|
||||
ureq = "=2.9"
|
||||
|
||||
|
|
|
@ -153,6 +153,7 @@ pub enum Config {
|
|||
Bloom,
|
||||
Mpt,
|
||||
Gpt2,
|
||||
Gptj,
|
||||
GptNeox,
|
||||
Phi,
|
||||
#[serde(rename = "phi-msft")]
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
use crate::infer::InferError;
|
||||
use crate::{
|
||||
ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
|
||||
};
|
||||
use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
|
||||
use minijinja::{Environment, ErrorKind, Template};
|
||||
use minijinja_contrib::pycompat;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// Raise an exception (custom function) used in the chat templates
|
||||
pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
|
||||
|
@ -16,6 +15,7 @@ pub(crate) struct ChatTemplate {
|
|||
bos_token: Option<String>,
|
||||
eos_token: Option<String>,
|
||||
use_default_tool_template: bool,
|
||||
variables: HashSet<String>,
|
||||
}
|
||||
|
||||
impl ChatTemplate {
|
||||
|
@ -29,48 +29,70 @@ impl ChatTemplate {
|
|||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||
let template_str = template.into_boxed_str();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
tracing::debug!("Loading template: {:#?}", template_str);
|
||||
|
||||
// check if contains the tools variable within the template
|
||||
let use_default_tool_template =
|
||||
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
|
||||
// leaking env and template_str as read-only, static resources for performance.
|
||||
let template = Box::leak(env)
|
||||
.template_from_str(Box::leak(template_str))
|
||||
.unwrap();
|
||||
|
||||
// get the list of variables that are used in the template
|
||||
let variables = template.undeclared_variables(true);
|
||||
// check if the `tools` variable is used in the template
|
||||
let use_default_tool_template = !variables.contains("tools");
|
||||
tracing::debug!("Use default tool template: {}", use_default_tool_template);
|
||||
|
||||
Self {
|
||||
template,
|
||||
bos_token: bos_token.map(|token| token.as_str().to_string()),
|
||||
eos_token: eos_token.map(|token| token.as_str().to_string()),
|
||||
use_default_tool_template,
|
||||
variables,
|
||||
}
|
||||
}
|
||||
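ChatTemplate::new now asks minijinja for the template's undeclared variables instead of string-searching for "{{tools}}", and that same set later backs the guideline check. A short sketch of what undeclared_variables(true) reports for a toy template (the template text here is made up for illustration):

    use minijinja::Environment;

    fn main() {
        let mut env = Environment::new();
        env.add_template(
            "chat",
            "{{ bos_token }}{% for m in messages %}{{ m.content }}{% endfor %}{% if tools %}{{ tools }}{% endif %}",
        )
        .unwrap();
        let template = env.get_template("chat").unwrap();
        // `true` also walks nested expressions; the code above uses this set to decide
        // whether the template already handles `tools` and whether `guideline` is required.
        let variables = template.undeclared_variables(true);
        assert!(variables.contains("tools"));
        assert!(variables.contains("messages"));
        assert!(variables.contains("bos_token"));
    }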
|
||||
pub(crate) fn apply(
|
||||
&self,
|
||||
guideline: Option<&str>,
|
||||
mut messages: Vec<Message>,
|
||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
||||
tools_and_prompt: Option<(Vec<Tool>, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
if self.use_default_tool_template {
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
||||
last_message.content.push(MessageChunk::Text {
|
||||
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
||||
});
|
||||
}
|
||||
}
|
||||
// check if guideline is expected but not provided
|
||||
if self.variables.contains("guideline") && guideline.is_none() {
|
||||
return Err(InferError::MissingTemplateVariable("guideline".to_string()));
|
||||
}
|
||||
|
||||
let tools = match tools_and_prompt {
|
||||
Some((tools, tool_prompt)) => {
|
||||
// check if the `tools` variable is used in the template
|
||||
// if not, we need to append the tools to the last message
|
||||
let text = if self.use_default_tool_template {
|
||||
match serde_json::to_string(&tools) {
|
||||
Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
|
||||
Err(e) => return Err(InferError::ToolError(e.to_string())),
|
||||
}
|
||||
} else {
|
||||
// if the `tools` variable is used in the template, we just append the tool_prompt
|
||||
format!("\n---\n{}", tool_prompt)
|
||||
};
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
last_message.content.push(MessageChunk::Text { text });
|
||||
}
|
||||
Some(tools)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||
|
||||
self.template
|
||||
.render(ChatTemplateInputs {
|
||||
guideline,
|
||||
messages,
|
||||
bos_token: self.bos_token.as_deref(),
|
||||
eos_token: self.eos_token.as_deref(),
|
||||
add_generation_prompt: true,
|
||||
tools: None,
|
||||
tools_prompt: None,
|
||||
tools,
|
||||
})
|
||||
.map_err(InferError::TemplateError)
|
||||
}
|
||||
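When use_default_tool_template is set, apply serializes the tools and appends them plus the tool prompt to the last message as "\n---\n<tools json>\n<tool prompt>". A small sketch of just that string construction, standing in for the Tool type with a raw serde_json::Value:

    fn default_tool_suffix(tools: &serde_json::Value, tool_prompt: &str) -> serde_json::Result<String> {
        // Mirrors the default-tool-template branch: JSON-encode the tools, then
        // append them and the prompt behind a "---" separator.
        let tools_str = serde_json::to_string(tools)?;
        Ok(format!("\n---\n{}\n{}", tools_str, tool_prompt))
    }

    fn main() -> serde_json::Result<()> {
        let tools = serde_json::json!([
            { "type": "function", "function": { "name": "get_current_weather" } }
        ]);
        let suffix = default_tool_suffix(&tools, "Use a tool if it helps.")?;
        assert!(suffix.starts_with("\n---\n["));
        assert!(suffix.ends_with("Use a tool if it helps."));
        println!("{suffix}");
        Ok(())
    }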
|
@ -80,7 +102,10 @@ impl ChatTemplate {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::infer::chat_template::raise_exception;
|
||||
use crate::{ChatTemplateInputs, TextMessage};
|
||||
use crate::infer::ChatTemplate;
|
||||
use crate::{
|
||||
ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool,
|
||||
};
|
||||
use minijinja::Environment;
|
||||
|
||||
#[test]
|
||||
|
@ -731,6 +756,19 @@ mod tests {
|
|||
},
|
||||
target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "google/shieldgemma-9b",
|
||||
chat_template: "{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n<end_of_turn>\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n<end_of_turn>\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n<end_of_turn>\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n",
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat_with_system.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
guideline: Some("Do not use offensive language."),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<s>You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\n\n<start_of_turn>\nHuman Question: I'd like to show off how chat templating works!\n<end_of_turn>\n\nOur safety principle is defined in the below:\n\n* Do not use offensive language.\n\n===\n\nDoes the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\n\n",
|
||||
},
|
||||
];
|
||||
|
||||
#[allow(unused_variables)] // name is unused
|
||||
|
@ -755,4 +793,116 @@ mod tests {
|
|||
assert_eq!(result, target);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_invalid_with_guideline() {
|
||||
let ct = ChatTemplate::new(
|
||||
"{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n<end_of_turn>\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n<end_of_turn>\\n\" }}\n {{- \"<start_of_turn>\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n<end_of_turn>\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n".to_string(),
|
||||
Some(TokenizerConfigToken::String("<s>".to_string())),
|
||||
Some(TokenizerConfigToken::String("</s>".to_string())),
|
||||
);
|
||||
|
||||
// convert TextMessage to Message
|
||||
let msgs: Vec<Message> = vec![
|
||||
Message {
|
||||
name: None,
|
||||
role: "user".to_string(),
|
||||
content: MessageContent::SingleText(
|
||||
"I'd like to show off how chat templating works!".to_string(),
|
||||
),
|
||||
},
|
||||
Message {
|
||||
name: None,
|
||||
role: "assistant".to_string(),
|
||||
content: MessageContent::SingleText(
|
||||
"I'm doing great. How can I help you today?".to_string(),
|
||||
),
|
||||
},
|
||||
Message {
|
||||
name: None,
|
||||
role: "user".to_string(),
|
||||
content: MessageContent::SingleText("Hello, how are you?".to_string()),
|
||||
},
|
||||
];
|
||||
|
||||
let result = ct.apply(None, msgs, None);
|
||||
|
||||
match result {
|
||||
Ok(_) => panic!("Should have failed since no guideline is provided"),
|
||||
Err(e) => {
|
||||
assert_eq!(e.to_string(), "Missing template vatiable: guideline")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_with_default_tool_template() {
|
||||
let ct = ChatTemplate::new(
|
||||
"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}".to_string(),
|
||||
Some(TokenizerConfigToken::String("<s>".to_string())),
|
||||
Some(TokenizerConfigToken::String("</s>".to_string())),
|
||||
);
|
||||
|
||||
// convert TextMessage to Message
|
||||
let msgs: Vec<Message> = vec![
|
||||
Message {
|
||||
name: None,
|
||||
role: "user".to_string(),
|
||||
content: MessageContent::SingleText(
|
||||
"I'd like to show off how chat templating works!".to_string(),
|
||||
),
|
||||
},
|
||||
Message {
|
||||
name: None,
|
||||
role: "assistant".to_string(),
|
||||
content: MessageContent::SingleText("Great! How can I help you today?".to_string()),
|
||||
},
|
||||
Message {
|
||||
name: None,
|
||||
role: "user".to_string(),
|
||||
content: MessageContent::SingleText("Just testing".to_string()),
|
||||
},
|
||||
];
|
||||
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
|
||||
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
|
||||
let tool_prompt = "This default prompt will be used".to_string();
|
||||
let tools_and_prompt = Some((tools, tool_prompt));
|
||||
let result = ct.apply(None, msgs, tools_and_prompt);
|
||||
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
|
||||
assert_eq!(result.unwrap(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_with_custom_tool_template() {
|
||||
// chat template from meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||
let ct = ChatTemplate::new(
|
||||
"{{- bos_token }}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n".to_string(),
|
||||
Some(TokenizerConfigToken::String("<s>".to_string())),
|
||||
Some(TokenizerConfigToken::String("</s>".to_string())),
|
||||
);
|
||||
let msgs: Vec<Message> = vec![
|
||||
Message {
|
||||
name: None,
|
||||
role: "system".to_string(),
|
||||
content: MessageContent::SingleText(
|
||||
"Youre a helpful assistant! Answer the users question best you can."
|
||||
.to_string(),
|
||||
),
|
||||
},
|
||||
Message {
|
||||
name: None,
|
||||
role: "user".to_string(),
|
||||
content: MessageContent::SingleText(
|
||||
"What is the weather like in Brooklyn, New York?".to_string(),
|
||||
),
|
||||
},
|
||||
];
|
||||
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
|
||||
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
|
||||
let tool_prompt = "This default prompt will be used".to_string();
|
||||
let tools_and_prompt = Some((tools, tool_prompt));
|
||||
let result = ct.apply(None, msgs, tools_and_prompt);
|
||||
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
|
||||
assert_eq!(result.unwrap(), expected);
|
||||
}
|
||||
}
@ -3,7 +3,7 @@ mod chat_template;
pub mod tool_grammar;
|
||||
|
||||
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
||||
use crate::GrammarType;
|
||||
use crate::Tool;
|
||||
use crate::{
|
||||
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
||||
Message, PrefillToken, Token,
@ -120,10 +120,11 @@ impl Infer {
) -> Result<Option<tokenizers::Encoding>, InferError> {
|
||||
// Tokenize request
|
||||
let inputs = request.inputs;
|
||||
let add_special_tokens = request.add_special_tokens;
|
||||
let truncate = request.parameters.truncate;
|
||||
let encoding = self
|
||||
.validation
|
||||
.tokenize(inputs, truncate)
|
||||
.tokenize(inputs, add_special_tokens, truncate)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
tracing::error!("Tokenization {err}");
@ -138,13 +139,14 @@ impl Infer {
#[instrument(skip_all)]
|
||||
pub(crate) fn apply_chat_template(
|
||||
&self,
|
||||
guideline: Option<String>,
|
||||
messages: Vec<Message>,
|
||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
||||
tools_and_prompt: Option<(Vec<Tool>, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
self.chat_template
|
||||
.as_ref()
|
||||
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
||||
.apply(messages, grammar_with_prompt)
|
||||
.apply(guideline.as_deref(), messages, tools_and_prompt)
|
||||
.map_err(|e| {
|
||||
metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
|
||||
tracing::error!("{e}");
@ -336,6 +338,8 @@ pub enum InferError {
IncompleteGeneration,
|
||||
#[error("Template error: {0}")]
|
||||
TemplateError(#[from] minijinja::Error),
|
||||
#[error("Missing template vatiable: {0}")]
|
||||
MissingTemplateVariable(String),
|
||||
#[error("Tool error: {0}")]
|
||||
ToolError(String),
|
||||
}
@ -348,6 +352,7 @@ impl InferError {
InferError::ValidationError(_) => "validation",
|
||||
InferError::IncompleteGeneration => "incomplete_generation",
|
||||
InferError::TemplateError(_) => "template_error",
|
||||
InferError::MissingTemplateVariable(_) => "missing_template_variable",
|
||||
InferError::ToolError(_) => "tool_error",
|
||||
}
|
||||
}
@ -1,5 +1,8 @@
use crate::infer::InferError;
|
||||
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools};
|
||||
use crate::{
|
||||
FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
|
||||
ToolType,
|
||||
};
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::collections::HashMap;
@ -16,17 +19,38 @@ impl ToolGrammar {
}
|
||||
|
||||
pub fn apply(
|
||||
tools: Option<Vec<Tool>>,
|
||||
tools: Vec<Tool>,
|
||||
tool_choice: ToolChoice,
|
||||
) -> Result<Option<Tools>, InferError> {
|
||||
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
|
||||
// if no tools are provided, return early with no tool schema
|
||||
let tools = match tools {
|
||||
Some(tools) if !tools.is_empty() => tools,
|
||||
_ => return Ok(None),
|
||||
};
|
||||
if tools.is_empty() {
|
||||
return Ok((tools, None));
|
||||
}
|
||||
|
||||
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
|
||||
|
||||
let mut tools = tools.clone();
|
||||
|
||||
// add the notify_error function to the tools
|
||||
let notify_error = Tool {
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionDefinition {
|
||||
name: "notify_error".to_string(),
|
||||
description: Some("Notify an error or issue".to_string()),
|
||||
arguments: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"error": {
|
||||
"type": "string",
|
||||
"description": "The error or issue to notify"
|
||||
}
|
||||
},
|
||||
"required": ["error"]
|
||||
}),
|
||||
},
|
||||
};
|
||||
tools.push(notify_error);
|
||||
|
||||
// if tools are provided and no tool_choice is set, default to OneOf
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolType::FunctionName(name) => {
@ -35,87 +59,57 @@ impl ToolGrammar {
ToolType::Function { function } => {
|
||||
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
||||
}
|
||||
ToolType::OneOf => tools,
|
||||
ToolType::NoTool => return Ok(None),
|
||||
ToolType::OneOf => tools.clone(),
|
||||
ToolType::NoTool => return Ok((tools, None)),
|
||||
};
|
||||
|
||||
// adds the error notification function for LLM feedback if required
|
||||
let mut text_response_properties = Map::new();
|
||||
text_response_properties.insert(
|
||||
"error".to_string(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"description": "The error or issue to notify"
|
||||
}),
|
||||
);
|
||||
text_response_properties.insert(
|
||||
"_name".to_string(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"const": "notify_error"
|
||||
}),
|
||||
);
|
||||
|
||||
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
let func = tool.function.clone();
|
||||
|
||||
// Clone the existing parameters, which are expected to be a JSON object
|
||||
let mut params = if let Value::Object(params) = &func.arguments {
|
||||
params.clone()
|
||||
} else {
|
||||
Map::new()
|
||||
};
|
||||
let mut params = Map::new();
|
||||
|
||||
// Insert the function's description at the top level, outside of properties
|
||||
params.insert(
|
||||
"description".to_string(),
|
||||
Value::String(func.description.clone().unwrap_or_default()),
|
||||
Value::String(func.description.unwrap_or_default()),
|
||||
);
|
||||
|
||||
// Ensure 'properties' exists and is an object
|
||||
let properties = params
|
||||
.entry("properties".to_string())
|
||||
.or_insert_with(|| json!({}))
|
||||
.as_object_mut()
|
||||
.unwrap();
|
||||
let mut properties = Map::new();
|
||||
let mut required = vec![Value::String("_name".to_string())];
|
||||
|
||||
// Insert the constant for the function name inside 'properties'
|
||||
properties.insert(
|
||||
"_name".to_string(),
|
||||
json!({
|
||||
"type": "string",
|
||||
"const": func.name.clone(),
|
||||
// "description": "The name of the function"
|
||||
}),
|
||||
);
|
||||
|
||||
// Check if 'required' exists, and it is an array. If not, create an empty array.
|
||||
let required = params
|
||||
.entry("required".to_string())
|
||||
.or_insert_with(|| json!([]))
|
||||
.as_array_mut()
|
||||
.unwrap();
|
||||
|
||||
// Add 'name' to the 'required' array if it is not already present
|
||||
if !required.iter().any(|r| r == "_name") {
|
||||
required.push(json!("_name"));
|
||||
if let Value::Object(args) = func.arguments {
|
||||
if let Some(Value::Object(props)) = args.get("properties") {
|
||||
properties.extend(props.clone());
|
||||
}
|
||||
if let Some(Value::Array(reqs)) = args.get("required") {
|
||||
required.extend(reqs.clone());
|
||||
}
|
||||
params.insert(
|
||||
"additionalProperties".to_string(),
|
||||
Value::Bool(
|
||||
args.get("additionalProperties").and_then(|v| v.as_str())
|
||||
== Some("true"),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
params.insert("properties".to_string(), Value::Object(properties));
|
||||
params.insert("required".to_string(), Value::Array(required));
|
||||
|
||||
(func.name, Value::Object(params))
|
||||
})
|
||||
.chain([(
|
||||
"notify_error".to_string(),
|
||||
serde_json::json!({
|
||||
"properties": text_response_properties,
|
||||
"required": ["error", "_name"],
|
||||
"type": "object"
|
||||
}),
|
||||
)])
|
||||
.collect();
|
||||
|
||||
let tools = Tools {
|
||||
let tool_schema = JsonSchemaTool {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
@ -123,13 +117,10 @@ impl ToolGrammar {
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.chain(std::iter::once(FunctionRef {
|
||||
ref_path: "#/$functions/notify_error".to_string(),
|
||||
}))
|
||||
.collect(),
|
||||
},
|
||||
};
|
||||
|
||||
Ok(Some(tools))
|
||||
Ok((tools, Some(tool_schema)))
|
||||
}
|
||||
}
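A rough caller-side sketch of the new signature, using only the types visible in this diff (illustrative, not part of the change): an empty tool list short-circuits with no schema, otherwise notify_error is appended and a JsonSchemaTool constraining the call is returned.

// Hypothetical usage; assumes the crate types above are in scope.
let (tools, schema) = ToolGrammar::apply(Vec::new(), ToolChoice(None)).unwrap();
assert!(tools.is_empty() && schema.is_none());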
@ -205,6 +205,13 @@ impl State {
}
|
||||
}
|
||||
|
||||
if let Some(max_size) = max_size {
|
||||
if max_size == 0 {
|
||||
tracing::debug!("No capacity");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// Pad prefill_token_budget to be a multiple of block size
|
||||
let prefill_token_budget =
|
||||
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
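The padding above is the usual round-up-to-a-multiple idiom. A minimal standalone sketch (the function name and values are illustrative only, not part of this change):

// Round `tokens` up to the nearest multiple of `block_size`.
// e.g. pad_to_block(50, 16) == 64 and pad_to_block(64, 16) == 64.
fn pad_to_block(tokens: u32, block_size: u32) -> u32 {
    ((tokens + block_size - 1) / block_size) * block_size
}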
@ -1,10 +1,10 @@
/// Batching and inference logic
|
||||
use crate::infer::v2::queue::{Entry, Queue};
|
||||
use crate::infer::{
|
||||
Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
|
||||
Attention, Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
|
||||
};
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use crate::{FinishReason, PrefillToken, Token};
|
||||
use crate::{Attention, FinishReason, PrefillToken, Token};
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
@ -40,12 +40,18 @@ impl BackendV2 {
generation_health: Arc<AtomicBool>,
|
||||
) -> Self {
|
||||
// Infer shared state
|
||||
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
|
||||
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
|
||||
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
|
||||
attention
|
||||
.parse()
|
||||
.expect(&format!("Invalid attention was specified: `{attention}`"))
|
||||
} else {
|
||||
false
|
||||
Attention::Paged
|
||||
};
|
||||
let block_size = if attention == Attention::FlashDecoding {
|
||||
256
|
||||
} else {
|
||||
16
|
||||
};
|
||||
let block_size = if flashdecoding { 256 } else { 16 };
|
||||
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
@ -161,8 +167,8 @@ pub(crate) async fn batching_task(
};
|
||||
|
||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
|
||||
|
||||
let max_size =
|
||||
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
|
||||
// Try to get a new batch
|
||||
if let Some((mut new_entries, new_batch, span)) = queue
|
||||
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
@ -15,14 +15,60 @@ use tracing::warn;
use utoipa::ToSchema;
|
||||
use validation::Validation;
|
||||
|
||||
#[derive(PartialEq)]
|
||||
pub enum Attention {
|
||||
Paged,
|
||||
FlashDecoding,
|
||||
FlashInfer,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
pub fn block_size(&self) -> u32 {
|
||||
match self {
|
||||
Attention::FlashDecoding => 256,
|
||||
Attention::FlashInfer => 1,
|
||||
Attention::Paged => 16,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ParseError;
|
||||
|
||||
impl std::fmt::Display for ParseError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Cannot parse attention value")
|
||||
}
|
||||
}
|
||||
impl std::error::Error for ParseError {}
|
||||
|
||||
impl std::str::FromStr for Attention {
|
||||
type Err = ParseError;
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"paged" => Ok(Attention::Paged),
|
||||
"flashdecoding" => Ok(Attention::FlashDecoding),
|
||||
"flashinfer" => Ok(Attention::FlashInfer),
|
||||
_ => Err(ParseError),
|
||||
}
|
||||
}
|
||||
}
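Together with block_size() above, the expected resolution path looks roughly like the sketch below (assumed usage mirroring BackendV2::new; the helper name and the silent fallback on an invalid value are illustrative, not literal code from this diff):

// Resolve the attention backend from the ATTENTION env var, defaulting to
// Paged, then derive the KV-cache block size from it.
fn resolve_block_size() -> u32 {
    let attention: Attention = std::env::var("ATTENTION")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(Attention::Paged);
    attention.block_size() // Paged -> 16, FlashDecoding -> 256, FlashInfer -> 1
}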
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema)]
|
||||
pub(crate) struct VertexInstance {
|
||||
pub(crate) struct GenerateVertexInstance {
|
||||
#[schema(example = "What is Deep Learning?")]
|
||||
pub inputs: String,
|
||||
#[schema(nullable = true, default = "null", example = "null")]
|
||||
pub parameters: Option<GenerateParameters>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema)]
|
||||
#[serde(untagged)]
|
||||
enum VertexInstance {
|
||||
Generate(GenerateVertexInstance),
|
||||
Chat(ChatRequest),
|
||||
}
|
||||
|
||||
#[derive(Deserialize, ToSchema)]
|
||||
pub(crate) struct VertexRequest {
|
||||
#[serde(rename = "instances")]
@ -619,7 +665,7 @@ impl ChatCompletion {
message,
|
||||
logprobs: return_logprobs
|
||||
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
|
||||
finish_reason: details.finish_reason.to_string(),
|
||||
finish_reason: details.finish_reason.format(true),
|
||||
}],
|
||||
usage: Usage {
|
||||
prompt_tokens: details.prefill.len() as u32,
@ -811,10 +857,10 @@ pub(crate) struct ChatRequest {
pub tools: Option<Vec<Tool>>,
|
||||
|
||||
/// A prompt to be appended before the tools
|
||||
#[serde(default = "default_tool_prompt")]
|
||||
#[serde(default)]
|
||||
#[schema(
|
||||
nullable = true,
|
||||
example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\""
|
||||
example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
|
||||
)]
|
||||
pub tool_prompt: Option<String>,
@ -829,12 +875,15 @@ pub(crate) struct ChatRequest {
#[serde(default)]
|
||||
#[schema(nullable = true, default = "null", example = "null")]
|
||||
pub response_format: Option<GrammarType>,
|
||||
|
||||
/// A guideline to be used in the chat_template
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, default = "null", example = "null")]
|
||||
pub guideline: Option<String>,
|
||||
}
|
||||
|
||||
fn default_tool_prompt() -> Option<String> {
|
||||
Some(
|
||||
"\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
|
||||
)
|
||||
pub fn default_tool_prompt() -> String {
|
||||
"\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
@ -876,7 +925,7 @@ impl From<ToolTypeDeserializer> for ToolChoice {
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
|
||||
pub struct Tools {
|
||||
pub struct JsonSchemaTool {
|
||||
#[serde(flatten)]
|
||||
functions_map: FunctionsMap,
|
||||
properties: Properties,
@ -934,8 +983,8 @@ pub(crate) struct ChatTemplateInputs<'a> {
bos_token: Option<&'a str>,
|
||||
eos_token: Option<&'a str>,
|
||||
add_generation_prompt: bool,
|
||||
tools: Option<&'a str>,
|
||||
tools_prompt: Option<&'a str>,
|
||||
tools: Option<Vec<Tool>>,
|
||||
guideline: Option<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]
@ -981,8 +1030,10 @@ impl MessageContent {
pub fn push(&mut self, chunk: MessageChunk) {
|
||||
match self {
|
||||
MessageContent::SingleText(text) => {
|
||||
*self =
|
||||
MessageContent::MultipleChunks(vec![MessageChunk::Text { text: text.clone() }]);
|
||||
*self = MessageContent::MultipleChunks(vec![
|
||||
MessageChunk::Text { text: text.clone() },
|
||||
chunk,
|
||||
]);
|
||||
}
|
||||
MessageContent::MultipleChunks(chunks) => {
|
||||
chunks.push(chunk);
@ -1038,6 +1089,16 @@ pub(crate) struct GenerateRequest {
pub inputs: String,
|
||||
#[serde(default = "default_parameters")]
|
||||
pub parameters: GenerateParameters,
|
||||
|
||||
/// This is used internally because some requests
/// already contain the templated input, so we
/// shouldn't add the special tokens again.
|
||||
#[serde(default = "default_true", skip)]
|
||||
pub add_special_tokens: bool,
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
@ -1055,6 +1116,7 @@ impl From<CompatGenerateRequest> for GenerateRequest {
fn from(req: CompatGenerateRequest) -> Self {
|
||||
Self {
|
||||
inputs: req.inputs,
|
||||
add_special_tokens: true,
|
||||
parameters: req.parameters,
|
||||
}
|
||||
}
@ -1117,6 +1179,15 @@ impl std::fmt::Display for FinishReason {
}
|
||||
}
|
||||
|
||||
impl FinishReason {
|
||||
pub fn format(&self, use_stop: bool) -> String {
|
||||
match self {
|
||||
FinishReason::EndOfSequenceToken if use_stop => "stop".to_string(),
|
||||
_ => self.to_string(),
|
||||
}
|
||||
}
|
||||
}
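An illustrative check of the helper above, based only on the match arms shown here (the Display output for the other variants is unchanged):

// With use_stop = true, an end-of-sequence finish is reported as the
// OpenAI-style "stop" value instead of the Display representation.
assert_eq!(FinishReason::EndOfSequenceToken.format(true), "stop");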
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
pub(crate) struct BestOfSequence {
|
||||
#[schema(example = "test")]
@ -1157,6 +1228,12 @@ pub(crate) struct GenerateResponse {
pub details: Option<Details>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
pub(crate) struct ChatTokenizeResponse {
|
||||
pub(crate) tokenize_response: TokenizeResponse,
|
||||
pub(crate) templated_text: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
#[serde(transparent)]
|
||||
pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
@ -1169,6 +1246,8 @@ pub(crate) struct StreamDetails {
pub generated_tokens: u32,
|
||||
#[schema(nullable = true, example = 42)]
|
||||
pub seed: Option<u64>,
|
||||
#[schema(example = 1)]
|
||||
pub input_length: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, ToSchema)]
@ -1189,6 +1268,34 @@ pub(crate) struct ErrorResponse {
pub error_type: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema)]
|
||||
pub(crate) struct ModelInfo {
|
||||
#[schema(example = "gpt2")]
|
||||
pub id: String,
|
||||
#[schema(example = "model")]
|
||||
pub object: String,
|
||||
#[schema(example = 1686935002)]
|
||||
pub created: u64,
|
||||
#[schema(example = "openai")]
|
||||
pub owned_by: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema)]
|
||||
pub(crate) struct ModelsInfo {
|
||||
#[schema(example = "list")]
|
||||
pub object: String,
|
||||
pub data: Vec<ModelInfo>,
|
||||
}
|
||||
|
||||
impl Default for ModelsInfo {
|
||||
fn default() -> Self {
|
||||
ModelsInfo {
|
||||
object: "list".to_string(),
|
||||
data: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
@ -1296,6 +1403,35 @@ mod tests {
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_content_append() {
|
||||
let mut content = MessageContent::SingleText("Initial text".to_string());
|
||||
let chunk = MessageChunk::Text {
|
||||
text: "Additional text".to_string(),
|
||||
};
|
||||
|
||||
content.push(chunk);
|
||||
|
||||
match content {
|
||||
MessageContent::MultipleChunks(chunks) => {
|
||||
assert_eq!(chunks.len(), 2);
|
||||
assert_eq!(
|
||||
chunks[0],
|
||||
MessageChunk::Text {
|
||||
text: "Initial text".to_string()
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
chunks[1],
|
||||
MessageChunk::Text {
|
||||
text: "Additional text".to_string()
|
||||
}
|
||||
);
|
||||
}
|
||||
_ => panic!("Expected MultipleChunks, but got a different variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_request() {
|
||||
let json = json!({
@ -8,6 +8,7 @@ use crate::kserve::{
kserve_model_metadata, kserve_model_metadata_ready,
|
||||
};
|
||||
use crate::validation::ValidationError;
|
||||
use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance};
|
||||
use crate::{
|
||||
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
||||
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@ -23,6 +24,7 @@ use crate::{
VertexResponse,
|
||||
};
|
||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
|
||||
use crate::{ModelInfo, ModelsInfo};
|
||||
use async_stream::__private::AsyncStream;
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
@ -115,6 +117,133 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
Json(info.0)
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/v1/models",
|
||||
responses(
|
||||
(status = 200, description = "Served model info", body = ModelInfo),
|
||||
(status = 404, description = "Model not found", body = ErrorResponse),
|
||||
)
|
||||
)]
|
||||
#[instrument(skip(info))]
|
||||
/// Get model info
|
||||
async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
|
||||
Json(ModelsInfo {
|
||||
data: vec![ModelInfo {
|
||||
id: info.0.model_id.clone(),
|
||||
object: "model".to_string(),
|
||||
created: 0, // TODO: determine how to get this
|
||||
owned_by: info.0.model_id.clone(),
|
||||
}],
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/chat_tokenize",
|
||||
request_body = ChatRequest,
|
||||
responses((status = 200, description = "Templated and tokenized ChatRequest", body = ChatTokenizeResponse))
|
||||
)]
|
||||
async fn get_chat_tokenize(
|
||||
Extension(infer): Extension<Infer>,
|
||||
Json(req): Json<ChatRequest>,
|
||||
) -> Result<(HeaderMap, Json<ChatTokenizeResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
|
||||
let ChatRequest {
|
||||
model,
|
||||
max_tokens,
|
||||
messages,
|
||||
seed,
|
||||
stop,
|
||||
stream,
|
||||
tools,
|
||||
tool_choice,
|
||||
tool_prompt,
|
||||
temperature,
|
||||
response_format,
|
||||
guideline,
|
||||
..
|
||||
} = req;
|
||||
|
||||
let tool_prompt = tool_prompt.unwrap_or_default();
|
||||
let (inputs, _grammar, _using_tools) = prepare_chat_input(
|
||||
&infer,
|
||||
response_format,
|
||||
tools,
|
||||
tool_choice,
|
||||
&tool_prompt,
|
||||
guideline,
|
||||
messages,
|
||||
)?;
|
||||
|
||||
let generate_request = GenerateRequest {
|
||||
inputs,
|
||||
add_special_tokens: false,
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature,
|
||||
repetition_penalty: None,
|
||||
frequency_penalty: None,
|
||||
top_k: None,
|
||||
top_p: None,
|
||||
typical_p: None,
|
||||
do_sample: true,
|
||||
max_new_tokens: max_tokens,
|
||||
return_full_text: None,
|
||||
stop: stop.unwrap_or_default(),
|
||||
truncate: None,
|
||||
watermark: false,
|
||||
details: false,
|
||||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: None,
|
||||
grammar: _grammar,
|
||||
adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
|
||||
},
|
||||
};
|
||||
|
||||
let input = generate_request.inputs.clone();
|
||||
let encoding = infer.tokenize(generate_request).await?;
|
||||
if let Some(encoding) = encoding {
|
||||
let tokens: Vec<SimpleToken> = encoding
|
||||
.get_ids()
|
||||
.iter()
|
||||
.zip(encoding.get_offsets())
|
||||
.map(|(&id, &(start, stop))| {
|
||||
let text = input
|
||||
.chars()
|
||||
.skip(start)
|
||||
.take(stop - start)
|
||||
.collect::<String>();
|
||||
SimpleToken {
|
||||
id,
|
||||
text,
|
||||
start,
|
||||
stop,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let resp = ChatTokenizeResponse {
|
||||
tokenize_response: TokenizeResponse(tokens),
|
||||
templated_text: input,
|
||||
};
|
||||
Ok((HeaderMap::new(), Json(resp)))
|
||||
} else {
|
||||
Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(ErrorResponse {
|
||||
error: "No fast tokenizer or tokenizer.json for this model".to_string(),
|
||||
error_type: "no fast tokenizer".to_string(),
|
||||
}),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
tag = "Text Generation Inference",
@ -429,7 +558,7 @@ async fn generate_stream_internal(
} else {
|
||||
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
|
||||
// Keep permit as long as generate_stream lives
|
||||
Ok((_permit, _input_length, response_stream)) => {
|
||||
Ok((_permit, input_length, response_stream)) => {
|
||||
let mut index = 0;
|
||||
let mut response_stream = Box::pin(response_stream);
|
||||
// Server-Sent Event stream
@ -472,6 +601,7 @@ async fn generate_stream_internal(
finish_reason: generated_text.finish_reason,
|
||||
generated_tokens: generated_text.generated_tokens,
|
||||
seed: generated_text.seed,
|
||||
input_length,
|
||||
}),
|
||||
false => None,
|
||||
};
@ -649,6 +779,7 @@ async fn completions(
.iter()
|
||||
.map(|prompt| GenerateRequest {
|
||||
inputs: prompt.to_string(),
|
||||
add_special_tokens: true,
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature,
@ -697,21 +828,46 @@ async fn completions(
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
event
|
||||
.json_data(Completion::Chunk(Chunk {
|
||||
id: "".to_string(),
|
||||
created: current_time,
|
||||
let message = match stream_token.details {
|
||||
Some(details) => {
|
||||
let completion_tokens = details.generated_tokens;
|
||||
let prompt_tokens = details.input_length;
|
||||
let total_tokens = prompt_tokens + completion_tokens;
|
||||
|
||||
Completion::Final(CompletionFinal {
|
||||
id: String::new(),
|
||||
created: current_time,
|
||||
model: model_id.clone(),
|
||||
system_fingerprint: system_fingerprint.clone(),
|
||||
choices: vec![CompletionComplete {
|
||||
finish_reason: details.finish_reason.to_string(),
|
||||
index: index as u32,
|
||||
logprobs: None,
|
||||
text: stream_token.token.text,
|
||||
}],
|
||||
usage: Usage {
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
},
|
||||
})
|
||||
}
|
||||
None => Completion::Chunk(Chunk {
|
||||
id: String::new(),
|
||||
created: current_time,
|
||||
choices: vec![CompletionComplete {
|
||||
finish_reason: "".to_string(),
|
||||
finish_reason: String::new(),
|
||||
index: index as u32,
|
||||
logprobs: None,
|
||||
text: stream_token.token.text,
|
||||
}],
|
||||
|
||||
model: model_id.clone(),
|
||||
system_fingerprint: system_fingerprint.clone(),
|
||||
}))
|
||||
}),
|
||||
};
|
||||
|
||||
event
|
||||
.json_data(message)
|
||||
.unwrap_or_else(|_e| Event::default())
|
||||
};
@ -919,7 +1075,7 @@ async fn completions(
total_tokens += details.prefill.len() as u32 + details.generated_tokens;
|
||||
|
||||
Ok(CompletionComplete {
|
||||
finish_reason: details.finish_reason.to_string(),
|
||||
finish_reason: details.finish_reason.format(true),
|
||||
index: index as u32,
|
||||
logprobs: None,
|
||||
text: generation.generated_text,
@ -1021,80 +1177,36 @@ async fn chat_completions(
tool_prompt,
|
||||
temperature,
|
||||
response_format,
|
||||
guideline,
|
||||
..
|
||||
} = req;
|
||||
|
||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
||||
let max_new_tokens = max_tokens.or(Some(100));
|
||||
let logprobs = logprobs.unwrap_or(false);
|
||||
let tool_prompt = tool_prompt.unwrap_or_default();
|
||||
let tool_prompt = tool_prompt
|
||||
.filter(|s| !s.is_empty())
|
||||
.unwrap_or_else(default_tool_prompt);
|
||||
let stop = stop.unwrap_or_default();
|
||||
// enable greedy only when temperature is 0
|
||||
let (do_sample, temperature) = match temperature {
|
||||
Some(temperature) if temperature == 0.0 => (false, None),
|
||||
other => (true, other),
|
||||
};
|
||||
|
||||
// response_format and tools are mutually exclusive
|
||||
if response_format.is_some() && tools.as_ref().is_some() {
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: "Grammar and tools are mutually exclusive".to_string(),
|
||||
error_type: "grammar and tools".to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
// extract tool grammar if present
|
||||
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
|
||||
Ok(grammar) => grammar,
|
||||
Err(err) => {
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: err.to_string(),
|
||||
error_type: err.error_type().to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// determine the appropriate arguments for apply_chat_template
|
||||
let tools_grammar_prompt = tool_grammar
|
||||
.as_ref()
|
||||
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
|
||||
|
||||
let (tools_grammar_prompt, grammar) = match response_format {
|
||||
Some(response_format) => (None, Some(response_format)),
|
||||
None => (
|
||||
tools_grammar_prompt.clone(),
|
||||
tools_grammar_prompt.map(|(grammar, _)| grammar.clone()),
|
||||
),
|
||||
};
|
||||
|
||||
// apply chat template to flatten the request into a single input
|
||||
let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
|
||||
Ok(inputs) => inputs,
|
||||
Err(err) => {
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: err.to_string(),
|
||||
error_type: err.error_type().to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
};
|
||||
let (inputs, grammar, using_tools) = prepare_chat_input(
|
||||
&infer,
|
||||
response_format,
|
||||
tools,
|
||||
tool_choice,
|
||||
&tool_prompt,
|
||||
guideline,
|
||||
messages,
|
||||
)?;
|
||||
|
||||
// build the request passing some parameters
|
||||
let generate_request = GenerateRequest {
|
||||
inputs: inputs.to_string(),
|
||||
add_special_tokens: false,
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature,
@ -1138,7 +1250,7 @@ async fn chat_completions(
});
|
||||
|
||||
// replace the content with the tool calls if grammar is present
|
||||
let (content, tool_calls) = if tool_grammar.is_some() {
|
||||
let (content, tool_calls) = if using_tools {
|
||||
(None, Some(vec![stream_token.token.text]))
|
||||
} else {
|
||||
let content = if !stream_token.token.special {
@ -1159,7 +1271,7 @@ async fn chat_completions(
tool_calls,
|
||||
current_time,
|
||||
logprobs,
|
||||
stream_token.details.map(|d| d.finish_reason.to_string()),
|
||||
stream_token.details.map(|d| d.finish_reason.format(true)),
|
||||
),
|
||||
))
|
||||
.unwrap_or_else(|e| {
@ -1192,10 +1304,14 @@ async fn chat_completions(
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let (tool_calls, output) = if tool_grammar.is_some() {
|
||||
let gen_text_value: Value = serde_json::from_str(&generation.generated_text)
|
||||
.map_err(|e| InferError::ToolError(e.to_string()))?;
|
||||
|
||||
let (tool_calls, output) = if using_tools {
|
||||
let gen_text_value: Value =
|
||||
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
||||
InferError::ToolError(format!(
|
||||
"Failed to parse generated text: {} {:?}",
|
||||
e, generation.generated_text
|
||||
))
|
||||
})?;
|
||||
let function = gen_text_value.get("function").ok_or(InferError::ToolError(
|
||||
"No function found in generated text".to_string(),
|
||||
))?;
@ -1290,13 +1406,14 @@ async fn vertex_compatibility(
));
|
||||
}
|
||||
|
||||
// Process all instances
|
||||
let predictions = req
|
||||
.instances
|
||||
.iter()
|
||||
.map(|instance| {
|
||||
let generate_request = GenerateRequest {
|
||||
// Prepare futures for all instances
|
||||
let mut futures = Vec::with_capacity(req.instances.len());
|
||||
|
||||
for instance in req.instances.iter() {
|
||||
let generate_request = match instance {
|
||||
VertexInstance::Generate(instance) => GenerateRequest {
|
||||
inputs: instance.inputs.clone(),
|
||||
add_special_tokens: true,
|
||||
parameters: GenerateParameters {
|
||||
do_sample: true,
|
||||
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
@ -1305,31 +1422,117 @@ async fn vertex_compatibility(
decoder_input_details: true,
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
},
|
||||
VertexInstance::Chat(instance) => {
|
||||
let ChatRequest {
|
||||
model,
|
||||
max_tokens,
|
||||
messages,
|
||||
seed,
|
||||
stop,
|
||||
stream,
|
||||
tools,
|
||||
tool_choice,
|
||||
tool_prompt,
|
||||
temperature,
|
||||
response_format,
|
||||
guideline,
|
||||
presence_penalty,
|
||||
frequency_penalty,
|
||||
top_p,
|
||||
top_logprobs,
|
||||
..
|
||||
} = instance.clone();
|
||||
|
||||
async {
|
||||
generate_internal(
|
||||
Extension(infer.clone()),
|
||||
compute_type.clone(),
|
||||
Json(generate_request),
|
||||
span.clone(),
|
||||
)
|
||||
.await
|
||||
.map(|(_, Json(generation))| generation.generated_text)
|
||||
.map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Incomplete generation".into(),
|
||||
error_type: "Incomplete generation".into(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
||||
let max_new_tokens = max_tokens.or(Some(100));
|
||||
let tool_prompt = tool_prompt
|
||||
.filter(|s| !s.is_empty())
|
||||
.unwrap_or_else(default_tool_prompt);
|
||||
let stop = stop.unwrap_or_default();
|
||||
// enable greedy only when temperature is 0
|
||||
let (do_sample, temperature) = match temperature {
|
||||
Some(temperature) if temperature == 0.0 => (false, None),
|
||||
other => (true, other),
|
||||
};
|
||||
let (inputs, grammar, _using_tools) = match prepare_chat_input(
|
||||
&infer,
|
||||
response_format,
|
||||
tools,
|
||||
tool_choice,
|
||||
&tool_prompt,
|
||||
guideline,
|
||||
messages,
|
||||
) {
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ErrorResponse {
|
||||
error: format!("Failed to prepare chat input: {}", e),
|
||||
error_type: "Input preparation error".to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
GenerateRequest {
|
||||
inputs: inputs.to_string(),
|
||||
add_special_tokens: false,
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
frequency_penalty,
|
||||
top_k: None,
|
||||
top_p,
|
||||
typical_p: None,
|
||||
do_sample,
|
||||
max_new_tokens,
|
||||
return_full_text: None,
|
||||
stop,
|
||||
truncate: None,
|
||||
watermark: false,
|
||||
details: true,
|
||||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: top_logprobs,
|
||||
grammar,
|
||||
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
|
||||
},
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect::<FuturesUnordered<_>>()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await?;
|
||||
};
|
||||
|
||||
let infer_clone = infer.clone();
|
||||
let compute_type_clone = compute_type.clone();
|
||||
let span_clone = span.clone();
|
||||
|
||||
futures.push(async move {
|
||||
generate_internal(
|
||||
Extension(infer_clone),
|
||||
compute_type_clone,
|
||||
Json(generate_request),
|
||||
span_clone,
|
||||
)
|
||||
.await
|
||||
.map(|(_, Json(generation))| generation.generated_text)
|
||||
.map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Incomplete generation".into(),
|
||||
error_type: "Incomplete generation".into(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
// execute all futures in parallel, collect results, returning early if any error occurs
|
||||
let results = futures::future::join_all(futures).await;
|
||||
let predictions: Result<Vec<_>, _> = results.into_iter().collect();
|
||||
let predictions = predictions?;
|
||||
|
||||
let response = VertexResponse { predictions };
|
||||
Ok((HeaderMap::new(), Json(response)).into_response())
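The collection step above is the usual join_all-then-collect pattern: folding Vec<Result<_, _>> into Result<Vec<_>, _> makes the first failed prediction fail the whole Vertex response. A standalone sketch under assumed, simplified types (not code from this diff):

// Run every prediction future, then fail the whole batch on the first error.
async fn collect_predictions<F>(futures: Vec<F>) -> Result<Vec<String>, String>
where
    F: std::future::Future<Output = Result<String, String>>,
{
    let results = futures::future::join_all(futures).await;
    results.into_iter().collect()
}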
@ -1360,8 +1563,11 @@ async fn tokenize(
.iter()
|
||||
.zip(encoding.get_offsets())
|
||||
.map(|(&id, &(start, stop))| {
|
||||
let text: String =
|
||||
String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string();
|
||||
let text = input
|
||||
.chars()
|
||||
.skip(start)
|
||||
.take(stop - start)
|
||||
.collect::<String>();
|
||||
SimpleToken {
|
||||
id,
|
||||
text,
@ -1409,6 +1615,7 @@ chat_completions,
completions,
|
||||
tokenize,
|
||||
metrics,
|
||||
openai_get_model_info,
|
||||
),
|
||||
components(
|
||||
schemas(
@ -1461,6 +1668,7 @@ ToolCall,
Function,
|
||||
FunctionDefinition,
|
||||
ToolChoice,
|
||||
ModelInfo,
|
||||
)
|
||||
),
|
||||
tags(
@ -1913,6 +2121,120 @@ async fn start(
.install_recorder()
|
||||
.expect("failed to install metrics recorder");
|
||||
|
||||
// Metrics descriptions
|
||||
metrics::describe_counter!("tgi_request_success", "Number of successful requests");
|
||||
metrics::describe_histogram!(
|
||||
"tgi_request_duration",
|
||||
metrics::Unit::Seconds,
|
||||
"Request duration"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"tgi_request_validation_duration",
|
||||
metrics::Unit::Seconds,
|
||||
"Request validation duration"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"tgi_request_queue_duration",
|
||||
metrics::Unit::Seconds,
|
||||
"Request queue duration"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"tgi_request_inference_duration",
|
||||
metrics::Unit::Seconds,
|
||||
"Request inference duration"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"tgi_request_mean_time_per_token_duration",
|
||||
metrics::Unit::Seconds,
|
||||
"Mean time per token per request"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"tgi_request_generated_tokens",
|
||||
metrics::Unit::Count,
|
||||
"Generated tokens per request"
|
||||
);
|
||||
metrics::describe_counter!(
|
||||
"tgi_batch_inference_count",
|
||||
metrics::Unit::Count,
|
||||
"Inference calls per method (prefill or decode)"
|
||||
);
|
||||
metrics::describe_counter!(
|
||||
"tgi_request_count",
|
||||
metrics::Unit::Count,
|
||||
"Total number of requests"
|
||||
);
|
||||
metrics::describe_counter!(
|
||||
"tgi_batch_inference_success",
|
||||
metrics::Unit::Count,
|
||||
"Number of successful inference calls per method (prefill or decode)"
|
||||
);
|
||||
metrics::describe_gauge!(
|
||||
"tgi_batch_current_size",
|
||||
metrics::Unit::Count,
|
||||
"Current batch size"
|
||||
);
|
||||
metrics::describe_gauge!("tgi_queue_size", metrics::Unit::Count, "Current queue size");
|
||||
metrics::describe_gauge!(
|
||||
"tgi_batch_current_max_tokens",
|
||||
metrics::Unit::Count,
|
||||
"Maximum tokens for the current batch"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"tgi_request_max_new_tokens",
|
||||
metrics::Unit::Count,
|
||||
"Maximum new tokens per request"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"tgi_batch_inference_duration",
|
||||
metrics::Unit::Seconds,
|
||||
"Batch inference duration"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"tgi_batch_forward_duration",
|
||||
metrics::Unit::Seconds,
|
||||
"Batch forward duration per method (prefill or decode)"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"tgi_request_skipped_tokens",
|
||||
metrics::Unit::Count,
|
||||
"Speculated tokens per request"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"tgi_batch_filter_duration",
|
||||
metrics::Unit::Seconds,
|
||||
"Time spent filtering batches and sending generated tokens per method (prefill or decode)"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"tgi_request_queue_duration",
|
||||
metrics::Unit::Seconds,
|
||||
"Time spent in the queue per request"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"tgi_request_validation_duration",
|
||||
metrics::Unit::Seconds,
|
||||
"Time spent validating the request"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"tgi_request_duration",
|
||||
metrics::Unit::Seconds,
|
||||
"Total time spent processing the request"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"tgi_batch_decode_duration",
|
||||
metrics::Unit::Seconds,
|
||||
"Time spent decoding a batch per method (prefill or decode)"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"tgi_request_input_length",
|
||||
metrics::Unit::Count,
|
||||
"Input token length per request"
|
||||
);
|
||||
metrics::describe_histogram!(
|
||||
"tgi_batch_next_size",
|
||||
metrics::Unit::Count,
|
||||
"Batch size of the next batch"
|
||||
);
|
||||
|
||||
// CORS layer
|
||||
let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
|
||||
let cors_layer = CorsLayer::new()
@ -2036,10 +2358,12 @@ async fn start(
}
|
||||
let info_routes = Router::new()
|
||||
.route("/", get(health))
|
||||
.route("/chat_tokenize", post(get_chat_tokenize))
|
||||
.route("/info", get(get_model_info))
|
||||
.route("/health", get(health))
|
||||
.route("/ping", get(health))
|
||||
.route("/metrics", get(metrics));
|
||||
.route("/metrics", get(metrics))
|
||||
.route("/v1/models", get(openai_get_model_info));
|
||||
|
||||
// Conditional AWS Sagemaker route
|
||||
let aws_sagemaker_route = if messages_api_enabled {
@ -2232,6 +2556,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
};
@ -2332,3 +2657,157 @@ fn create_post_processor(
|
||||
Ok(post_processor)
|
||||
}
|
||||
|
||||
type PreparedInput = (String, Option<GrammarType>, bool);
|
||||
|
||||
fn prepare_chat_input(
|
||||
infer: &Infer,
|
||||
response_format: Option<GrammarType>,
|
||||
tools: Option<Vec<Tool>>,
|
||||
tool_choice: ToolChoice,
|
||||
tool_prompt: &str,
|
||||
guideline: Option<String>,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<PreparedInput, InferError> {
|
||||
if response_format.is_some() && tools.is_some() {
|
||||
return Err(InferError::ToolError(
|
||||
"Grammar and tools are mutually exclusive".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// when response_format is set, tools are not included when applying the chat template to generate inputs
|
||||
if let Some(format) = response_format {
|
||||
let inputs = infer.apply_chat_template(guideline, messages, None)?;
|
||||
return Ok((inputs, Some(format), false));
|
||||
}
|
||||
|
||||
// when no response_format is set and tools are included, apply the chat template with the tools
|
||||
// to generate inputs
|
||||
if let Some(tools) = tools {
|
||||
let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?;
|
||||
|
||||
let grammar = tool_schema
|
||||
.as_ref()
|
||||
.map(|t| GrammarType::Json(serde_json::json!(t)));
|
||||
|
||||
let inputs: String = infer.apply_chat_template(
|
||||
guideline,
|
||||
messages,
|
||||
Some((updated_tools, tool_prompt.into())),
|
||||
)?;
|
||||
return Ok((inputs, grammar, tool_schema.is_some()));
|
||||
}
|
||||
|
||||
// if neither response_format nor tools are set, simply apply the chat template to generate inputs
|
||||
let inputs = infer.apply_chat_template(guideline, messages, None)?;
|
||||
Ok((inputs, None, false))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::ChatTemplateVersions;
|
||||
use crate::HubTokenizerConfig;
|
||||
use crate::TokenizerConfigToken;
|
||||
use crate::Tool;
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_prepare_chat_input() {
|
||||
// Mock Backend to avoid network requests
|
||||
struct MockBackend;
|
||||
|
||||
impl Backend for MockBackend {
|
||||
fn schedule(
|
||||
&self,
|
||||
_request: crate::validation::ValidGenerateRequest,
|
||||
) -> Result<
|
||||
tokio_stream::wrappers::UnboundedReceiverStream<
|
||||
Result<InferStreamResponse, InferError>,
|
||||
>,
|
||||
InferError,
|
||||
> {
|
||||
unimplemented!("Never called in this test");
|
||||
}
|
||||
fn health<'a, 'async_trait>(
|
||||
&'a self,
|
||||
_current_health: bool,
|
||||
) -> core::pin::Pin<
|
||||
Box<dyn core::future::Future<Output = bool> + core::marker::Send + 'async_trait>,
|
||||
>
|
||||
where
|
||||
'a: 'async_trait,
|
||||
Self: 'async_trait,
|
||||
{
|
||||
unimplemented!("Never called in this test");
|
||||
}
|
||||
}
|
||||
|
||||
let backend = MockBackend {};
|
||||
|
||||
let mut tokenizer_config = HubTokenizerConfig::default();
|
||||
|
||||
// mock tokenizer config values
|
||||
tokenizer_config.bos_token = Some(TokenizerConfigToken::String("<s>".to_string()));
|
||||
tokenizer_config.eos_token = Some(TokenizerConfigToken::String("</s>".to_string()));
|
||||
tokenizer_config.chat_template = Some(
|
||||
ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string())
);

let infer = Infer::new(
backend,
Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false),
1,
tokenizer_config,
HubProcessorConfig::default(),
);
let response_format = None;
let tools = Some(vec![Tool {
r#type: "function".to_string(),
function: FunctionDefinition {
name: "get_current_weather".to_string(),
description: Some("Get the current weather".to_string()),
arguments: json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location."
}
},
"required": ["location", "format"]
}),
},
}]);
let tool_prompt = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.";
let guideline = None;
let messages = vec![Message {
name: None,
role: "user".to_string(),
content: MessageContent::SingleText(
"What is the weather like in New York?".to_string(),
),
}];

let result = prepare_chat_input(
&infer,
response_format,
tools,
ToolChoice(None),
tool_prompt,
guideline,
messages,
);

assert!(result.is_ok());
let (inputs, _grammar, using_tools) = result.unwrap();
assert_eq!(using_tools, true);
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
}
}

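The template above rejects any tool call id whose length is not exactly 9, and its error message describes valid ids as alphanumeric. A minimal, illustrative Rust check of that constraint (the helper below is hypothetical and not part of the router):

// Hypothetical helper mirroring the template's check: ids used in
// [TOOL_CALLS] / [TOOL_RESULTS] turns must be exactly 9 alphanumeric characters.
fn is_valid_tool_call_id(id: &str) -> bool {
    id.len() == 9 && id.chars().all(|c| c.is_ascii_alphanumeric())
}

fn main() {
    assert!(is_valid_tool_call_id("AbC123xyZ"));
    assert!(!is_valid_tool_call_id("call_1"));
}
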
@ -11,6 +11,7 @@ use rand::{thread_rng, Rng};
use serde_json::Value;
use std::io::Cursor;
use std::iter;
use std::sync::Arc;
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc;
@ -94,6 +95,7 @@ impl Validation {
pub async fn tokenize(
&self,
inputs: String,
add_special_tokens: bool,
truncate: Option<usize>,
) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
// If we have a fast tokenizer
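A minimal sketch of the new call shape, assuming a `validation: Validation` handle from this module is in scope and that `prompt` and `max_input_length` are caller-provided (names are illustrative only):

// Sketch only: re-tokenize an already templated prompt without adding
// special tokens again, truncating to the configured maximum input length.
let tokenized = validation
    .tokenize(prompt.clone(), false, Some(max_input_length))
    .await?;
if let Some((encoding, _chunks)) = tokenized {
    tracing::debug!("prompt is {} tokens long", encoding.len());
}
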
@ -103,7 +105,11 @@ impl Validation {
// Send request to the background validation task
// Unwrap is safe here
sender
.send(((inputs, truncate), response_sender, Span::current()))
.send((
(inputs, add_special_tokens, truncate),
response_sender,
Span::current(),
))
.unwrap();

// Await on response channel
@ -115,15 +121,20 @@ impl Validation {
}
}

#[allow(clippy::type_complexity)]
#[instrument(skip(self, inputs))]
async fn validate_input(
&self,
inputs: String,
add_special_tokens: bool,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
) -> Result<(Vec<Chunk>, usize, u32), ValidationError> {
) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
// If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
if let Some((encoding, inputs)) = self
.tokenize(inputs.clone(), add_special_tokens, truncate)
.await?
{
// Create response channel
let input_length = if let Some(truncate) = truncate {
std::cmp::min(encoding.len(), truncate)
@ -156,8 +167,11 @@ impl Validation {
));
}

let ids = encoding.get_ids();
let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();

metrics::histogram!("tgi_request_input_length").record(input_length as f64);
Ok((inputs, input_length, max_new_tokens))
Ok((inputs, Some(input_ids), input_length, max_new_tokens))
}
// Return inputs without validation
else {
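The slice above keeps only the trailing `input_length` ids; a small self-contained sketch of that behaviour (token values are made up):

fn main() {
    // With five ids and input_length = 3, only the last three are kept;
    // saturating_sub avoids an out-of-bounds start when the input is shorter.
    let ids: Vec<u32> = vec![1, 15043, 3186, 29991, 2];
    let input_length = 3usize;
    let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
    assert_eq!(input_ids, vec![3186, 29991, 2]);
}
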
@ -180,7 +194,12 @@ impl Validation {
input_length = input_length.saturating_sub(max_new_tokens as usize);
}

Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens))
Ok((
vec![Chunk::Text(inputs)],
None,
input_length,
max_new_tokens,
))
}
}
@ -314,8 +333,13 @@ impl Validation {
.unwrap_or(Ok(None))?;

// Validate inputs
let (inputs, input_length, max_new_tokens) = self
.validate_input(request.inputs, truncate, max_new_tokens)
let (inputs, input_ids, input_length, max_new_tokens) = self
.validate_input(
request.inputs,
request.add_special_tokens,
truncate,
max_new_tokens,
)
.await?;

// TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
@ -391,6 +415,8 @@ impl Validation {

Ok(ValidGenerateRequest {
inputs,
input_ids: input_ids.map(Arc::new),
add_special_tokens: request.add_special_tokens,
decoder_input_details,
input_length: input_length as u32,
truncate: truncate.unwrap_or(self.max_input_length) as u32,
@ -439,12 +465,15 @@ fn tokenizer_worker(
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
) {
// Loop over requests
while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
receiver.blocking_recv()
{
parent_span.in_scope(|| {
response_tx
.send(prepare_input(
inputs,
truncate,
add_special_tokens,
&tokenizer,
config.as_ref(),
preprocessor_config.as_ref(),
@ -581,6 +610,7 @@ fn image_tokens_fixup(config: &Config, text: String) -> String {
fn prepare_input(
inputs: String,
_truncate: Option<usize>,
add_special_tokens: bool,
tokenizer: &Tokenizer,
config: Option<&Config>,
preprocessor_config: Option<&HubPreprocessorConfig>,
@ -618,14 +648,14 @@ fn prepare_input(

// Get the number of tokens in the input
let encoding = tokenizer
.encode(tokenizer_query, true)
.encode(tokenizer_query, add_special_tokens)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;

Ok((encoding, input_chunks))
}

type TokenizerRequest = (
(String, Option<usize>),
(String, bool, Option<usize>),
oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
Span,
);
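A hedged sketch of how a request now crosses the channel, assuming the `TokenizerRequest` alias and the chunk/error types from this file are in scope (values are illustrative):

// Sketch only: the request tuple now carries the add_special_tokens flag from
// the async caller to the blocking tokenizer worker.
let (sender, _receiver) = tokio::sync::mpsc::unbounded_channel::<TokenizerRequest>();
let (response_tx, _response_rx) = tokio::sync::oneshot::channel();
sender
    .send((
        ("Hello".to_string(), true, Some(1024)),
        response_tx,
        tracing::Span::current(),
    ))
    .unwrap();
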
@ -707,8 +737,10 @@ pub struct ValidStoppingParameters {
#[derive(Debug, Clone)]
pub struct ValidGenerateRequest {
pub inputs: Vec<Chunk>,
pub input_ids: Option<Arc<Vec<u32>>>,
pub input_length: u32,
pub truncate: u32,
pub add_special_tokens: bool,
pub decoder_input_details: bool,
pub parameters: ValidParameters,
pub stopping_parameters: ValidStoppingParameters,
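The new `input_ids` field is wrapped in an `Arc` so the prefill token ids can be shared between components without copying; a small self-contained sketch:

use std::sync::Arc;

fn main() {
    // Cloning the Option<Arc<Vec<u32>>> clones the pointer, not the token ids.
    let input_ids: Option<Arc<Vec<u32>>> = Some(Arc::new(vec![1, 15043, 3186]));
    let shared = input_ids.clone();
    assert!(Arc::ptr_eq(input_ids.as_ref().unwrap(), shared.as_ref().unwrap()));
}
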
@ -815,11 +847,11 @@ mod tests {

let max_new_tokens = 10;
match validation
.validate_input("Hello".to_string(), None, Some(max_new_tokens))
.validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
.await
{
// Err(ValidationError::MaxNewTokens(1, 10)) => (),
Ok((_s, 0, 10)) => (),
Ok((_s, _, 0, 10)) => (),
r => panic!("Unexpected not max new tokens: {r:?}"),
}
}
@ -850,7 +882,7 @@ mod tests {

let max_new_tokens = 10;
match validation
.validate_input("Hello".to_string(), None, Some(max_new_tokens))
.validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
.await
{
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
@ -884,6 +916,7 @@ mod tests {
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
best_of: Some(2),
do_sample: false,
@ -923,6 +956,7 @@ mod tests {
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_p: Some(1.0),
max_new_tokens: Some(5),
@ -938,6 +972,7 @@ mod tests {
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_p: Some(0.99),
max_new_tokens: Some(5),
@ -953,6 +988,7 @@ mod tests {
let valid_request = validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_p: None,
max_new_tokens: Some(5),
@ -991,6 +1027,7 @@ mod tests {
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_n_tokens: Some(5),
max_new_tokens: Some(5),
@ -1006,6 +1043,7 @@ mod tests {
validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_n_tokens: Some(4),
max_new_tokens: Some(5),
@ -1018,6 +1056,7 @@ mod tests {
validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_n_tokens: Some(0),
max_new_tokens: Some(5),
@ -1030,6 +1069,7 @@ mod tests {
let valid_request = validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_n_tokens: None,
max_new_tokens: Some(5),
@ -1078,6 +1118,7 @@ mod tests {
let chunks = match validation
.tokenize(
format!("test![](data:image/gif;base64,{})", PIXEL_GIF),
true,
None,
)
.await
@ -1137,6 +1178,7 @@ mod tests {
"test![](data:image/gif;base64,{})![](data:image/gif;base64,{})",
PIXEL_GIF, PIXEL_GIF
),
true,
None,
)
.await

@ -1,5 +1,5 @@
[toolchain]
# Released on: June 13, 2024
# https://releases.rs/docs/1.79.0/
channel = "1.79.0"
channel = "1.80.0"
components = ["rustfmt", "clippy"]

@ -6,6 +6,8 @@ include Makefile-eetq
include Makefile-selective-scan
include Makefile-lorax-punica
include Makefile-fbgemm
include Makefile-exllamav2
include Makefile-flashinfer

unit-tests:
pytest -s -vv -m "not private" tests

@ -0,0 +1,12 @@
exllamav2_commit := v0.1.8

build-exllamav2:
git clone https://github.com/turboderp/exllamav2.git exllamav2 && \
cd exllamav2 && git fetch && git checkout $(exllamav2_commit) && \
git submodule update --init --recursive && \
pip install -r requirements.txt && \
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py build

install-exllamav2: build-exllamav2
cd exllamav2/ && \
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py install

@ -1,7 +1,9 @@
fbgemm_commit := ddac8dd9fc0bee70a3f456df68b8aac38576c856
fbgemm_commit := v0.8.0

build-fbgemm:
git clone https://github.com/pytorch/FBGEMM.git fbgemm && \
@if [ ! -d "fbgemm" ]; then \
git clone https://github.com/pytorch/FBGEMM.git fbgemm; \
fi
cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \
git submodule update --init --recursive && \
cd fbgemm_gpu && \

@ -0,0 +1,2 @@
install-flashinfer:
pip install flashinfer==0.1.5 -i https://flashinfer.ai/whl/cu124/torch2.4

@ -1,10 +1,7 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch

extra_compile_args = ["-std=c++17"]
if not torch.version.hip:
extra_compile_args.append("-arch=compute_80")

setup(
name="custom_kernels",

@ -1070,6 +1070,30 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
[package.extras]
dev = ["Sphinx (>=4.1.1)", "black (>=19.10b0)", "colorama (>=0.3.4)", "docutils (==0.16)", "flake8 (>=3.7.7)", "isort (>=5.1.1)", "pytest (>=4.6.2)", "pytest-cov (>=2.7.1)", "sphinx-autobuild (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "tox (>=3.9.0)"]

[[package]]
name = "markdown-it-py"
version = "3.0.0"
description = "Python port of markdown-it. Markdown parsing, done right!"
optional = false
python-versions = ">=3.8"
files = [
{file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
{file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
]

[package.dependencies]
mdurl = ">=0.1,<1.0"

[package.extras]
benchmarking = ["psutil", "pytest", "pytest-benchmark"]
code-style = ["pre-commit (>=3.0,<4.0)"]
compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
linkify = ["linkify-it-py (>=1,<3)"]
plugins = ["mdit-py-plugins"]
profiling = ["gprof2dot"]
rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]

[[package]]
name = "markupsafe"
version = "2.1.5"
@ -1207,6 +1231,17 @@ torch = "*"
type = "url"
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"

[[package]]
name = "mdurl"
version = "0.1.2"
description = "Markdown URL utilities"
optional = false
python-versions = ">=3.7"
files = [
{file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
]

[[package]]
name = "mpmath"
version = "1.3.0"
@ -2277,6 +2312,20 @@ files = [
[package.dependencies]
typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"

[[package]]
name = "pygments"
version = "2.18.0"
description = "Pygments is a syntax highlighting package written in Python."
optional = false
python-versions = ">=3.8"
files = [
{file = "pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a"},
{file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"},
]

[package.extras]
windows-terminal = ["colorama (>=0.4.6)"]

[[package]]
name = "pytest"
version = "7.4.4"
@ -2508,6 +2557,24 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]

[[package]]
name = "rich"
version = "13.7.1"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "rich-13.7.1-py3-none-any.whl", hash = "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222"},
{file = "rich-13.7.1.tar.gz", hash = "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"},
]

[package.dependencies]
markdown-it-py = ">=2.2.0"
pygments = ">=2.13.0,<3.0.0"

[package.extras]
jupyter = ["ipywidgets (>=7.5.1,<9)"]

[[package]]
name = "rpds-py"
version = "0.19.0"
@ -3170,11 +3237,6 @@ files = [
{file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"},
{file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"},
{file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"},
{file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"},
{file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"},
{file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"},
{file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"},
{file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"},
]

[package.dependencies]
@ -3584,4 +3646,4 @@ torch = ["torch"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.13"
content-hash = "a89867b23017d2efa8a7aa14d4764bcbd3b4dea9bfbf06a7a68464cb184ac6a1"
content-hash = "0ff7a244a409b616490cb238995bbe28dedf67ccb8855edafa2b71ee2e777dbd"