Merge branch 'main' into gpt_awq_4
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
commit
10628e878a
|
@ -0,0 +1,41 @@
|
|||
name: "Nix Tests"
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- ".github/workflows/nix_tests.yaml"
|
||||
- "server/**"
|
||||
- "proto/**"
|
||||
- "router/**"
|
||||
- "launcher/**"
|
||||
- "Cargo.lock"
|
||||
- "rust-toolchain.toml"
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
runs-on:
|
||||
group: aws-highmemory-32-plus-priv
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: cachix/install-nix-action@v27
|
||||
with:
|
||||
nix_path: nixpkgs=channel:nixos-unstable
|
||||
- uses: cachix/cachix-action@v14
|
||||
with:
|
||||
name: text-generation-inference
|
||||
# If you chose signing key for write access
|
||||
authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'
|
||||
env:
|
||||
USER: github_runner
|
||||
- name: Build
|
||||
run: nix develop .#test --command echo "Ok"
|
||||
- name: Pre-commit tests.
|
||||
run: nix develop .#test --command pre-commit run --all-files
|
||||
- name: Python tests.
|
||||
run: nix develop .#test --command python -m pytest server/tests/
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
- name: Rust tests.
|
||||
run: nix develop .#test --command cargo test
|
|
@ -17,19 +17,15 @@ concurrency:
|
|||
|
||||
jobs:
|
||||
run_tests:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
env:
|
||||
SCCACHE_GHA_ENABLED: "on"
|
||||
RUSTC_WRAPPER: /usr/local/bin/sccache
|
||||
SCCACHE: 0.3.3
|
||||
|
||||
runs-on:
|
||||
group: aws-highmemory-32-plus-priv
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v1
|
||||
uses: actions/setup-python@v4
|
||||
id: python
|
||||
with:
|
||||
python-version: 3.9
|
||||
python-version: 3.11
|
||||
- name: Install Rust
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
|
@ -44,30 +40,9 @@ jobs:
|
|||
run: |
|
||||
sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android
|
||||
sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
|
||||
- name: Install sccache
|
||||
run: |
|
||||
curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache
|
||||
chmod +x /usr/local/bin/sccache
|
||||
- name: configure sccache
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || '');
|
||||
core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');
|
||||
core.exportVariable('SCCACHE_GHA_CACHE_TO', 'sccache-${{runner.os}}-${{github.ref_name}}');
|
||||
core.exportVariable('SCCACHE_GHA_CACHE_FROM', 'sccache-${{runner.os}}-main,sccache-${{runner.os}}-');
|
||||
- name: cargo registry cache
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
key: cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-${{ github.sha }}
|
||||
restore-keys: |
|
||||
cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-
|
||||
cargo-${{ runner.os }}-
|
||||
path: |
|
||||
~/.cargo/registry
|
||||
~/.cargo/git
|
||||
- name: Install
|
||||
run: |
|
||||
sudo apt install python3.11-dev -y
|
||||
make install-cpu
|
||||
- name: Run server tests
|
||||
run: |
|
||||
|
@ -82,6 +57,3 @@ jobs:
|
|||
- name: Run Rust tests
|
||||
run: |
|
||||
cargo test
|
||||
- name: sccache stats
|
||||
run: |
|
||||
/usr/local/bin/sccache --show-stats
|
||||
|
|
|
@ -2118,6 +2118,15 @@ version = "2.7.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
|
||||
|
||||
[[package]]
|
||||
name = "memoffset"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "metrics"
|
||||
version = "0.23.0"
|
||||
|
@ -3112,6 +3121,69 @@ dependencies = [
|
|||
"prost 0.12.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3"
|
||||
version = "0.22.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"indoc",
|
||||
"libc",
|
||||
"memoffset",
|
||||
"once_cell",
|
||||
"portable-atomic",
|
||||
"pyo3-build-config",
|
||||
"pyo3-ffi",
|
||||
"pyo3-macros",
|
||||
"unindent",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-build-config"
|
||||
version = "0.22.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"target-lexicon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-ffi"
|
||||
version = "0.22.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"pyo3-build-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-macros"
|
||||
version = "0.22.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"pyo3-macros-backend",
|
||||
"quote",
|
||||
"syn 2.0.76",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-macros-backend"
|
||||
version = "0.22.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372"
|
||||
dependencies = [
|
||||
"heck 0.5.0",
|
||||
"proc-macro2",
|
||||
"pyo3-build-config",
|
||||
"quote",
|
||||
"syn 2.0.76",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "qoi"
|
||||
version = "0.4.1"
|
||||
|
@ -4068,7 +4140,7 @@ dependencies = [
|
|||
"pkg-config",
|
||||
"text-generation-router",
|
||||
"thiserror",
|
||||
"tokenizers",
|
||||
"tokenizers 0.19.1",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tracing",
|
||||
|
@ -4091,7 +4163,7 @@ dependencies = [
|
|||
"tabled",
|
||||
"text-generation-client",
|
||||
"thiserror",
|
||||
"tokenizers",
|
||||
"tokenizers 0.20.0",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
|
@ -4161,6 +4233,7 @@ dependencies = [
|
|||
"once_cell",
|
||||
"opentelemetry 0.20.0",
|
||||
"opentelemetry-otlp",
|
||||
"pyo3",
|
||||
"rand",
|
||||
"regex",
|
||||
"reqwest",
|
||||
|
@ -4168,7 +4241,7 @@ dependencies = [
|
|||
"serde_json",
|
||||
"sysinfo",
|
||||
"thiserror",
|
||||
"tokenizers",
|
||||
"tokenizers 0.20.0",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tower-http",
|
||||
|
@ -4219,7 +4292,7 @@ dependencies = [
|
|||
"slotmap",
|
||||
"text-generation-router",
|
||||
"thiserror",
|
||||
"tokenizers",
|
||||
"tokenizers 0.20.0",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tonic 0.10.2",
|
||||
|
@ -4374,6 +4447,39 @@ dependencies = [
|
|||
"unicode_categories",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokenizers"
|
||||
version = "0.20.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c8a24d7f7d6be5b9d1377418b893ab1808af0074f5d1bb2c64784452ddd2aa70"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"derive_builder",
|
||||
"esaxx-rs",
|
||||
"getrandom",
|
||||
"hf-hub",
|
||||
"indicatif",
|
||||
"itertools 0.12.1",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"macro_rules_attribute",
|
||||
"monostate",
|
||||
"onig",
|
||||
"paste",
|
||||
"rand",
|
||||
"rayon",
|
||||
"rayon-cond",
|
||||
"regex",
|
||||
"regex-syntax 0.8.4",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"spm_precompiled",
|
||||
"thiserror",
|
||||
"unicode-normalization-alignments",
|
||||
"unicode-segmentation",
|
||||
"unicode_categories",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.39.3"
|
||||
|
@ -4839,6 +4945,12 @@ version = "0.1.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
|
||||
|
||||
[[package]]
|
||||
name = "unindent"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
|
||||
|
||||
[[package]]
|
||||
name = "untrusted"
|
||||
version = "0.7.1"
|
||||
|
|
|
@ -25,7 +25,7 @@ homepage = "https://github.com/huggingface/text-generation-inference"
|
|||
|
||||
[workspace.dependencies]
|
||||
base64 = "0.22.0"
|
||||
tokenizers = { version = "0.19.1", features = ["http"] }
|
||||
tokenizers = { version = "0.20.0", features = ["http"] }
|
||||
hf-hub = { version = "0.3.1", features = ["tokio"] }
|
||||
metrics = { version = "0.23.0" }
|
||||
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
|
||||
|
|
40
Dockerfile
40
Dockerfile
|
@ -13,10 +13,13 @@ COPY benchmark benchmark
|
|||
COPY router router
|
||||
COPY backends backends
|
||||
COPY launcher launcher
|
||||
|
||||
RUN cargo chef prepare --recipe-path recipe.json
|
||||
|
||||
FROM chef AS builder
|
||||
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
python3.11-dev
|
||||
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
|
||||
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
|
||||
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
|
||||
|
@ -37,6 +40,7 @@ COPY router router
|
|||
COPY backends backends
|
||||
COPY launcher launcher
|
||||
RUN cargo build --profile release-opt
|
||||
RUN cargo build --profile release-opt
|
||||
|
||||
# Python builder
|
||||
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
|
||||
|
@ -45,7 +49,7 @@ FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install
|
|||
# NOTE: When updating the PyTorch version, remember to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
|
||||
ARG PYTORCH_VERSION=2.4.0
|
||||
|
||||
ARG PYTHON_VERSION=3.10
|
||||
ARG PYTHON_VERSION=3.11
|
||||
# Keep in sync with `server/pyproject.toml`
|
||||
ARG CUDA_VERSION=12.4
|
||||
ARG MAMBA_VERSION=24.3.0-0
|
||||
|
@ -216,33 +220,33 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
|||
COPY --from=pytorch-install /opt/conda /opt/conda
|
||||
|
||||
# Copy build artifacts from flash attention builder
|
||||
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
|
||||
# Copy build artifacts from flash attention v2 builder
|
||||
COPY --from=flash-att-v2-builder /opt/conda/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=flash-att-v2-builder /opt/conda/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so /opt/conda/lib/python3.11/site-packages
|
||||
|
||||
# Copy build artifacts from custom kernels builder
|
||||
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
# Copy build artifacts from exllama kernels builder
|
||||
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
# Copy build artifacts from exllamav2 kernels builder
|
||||
COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
# Copy build artifacts from awq kernels builder
|
||||
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
# Copy build artifacts from eetq kernels builder
|
||||
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
# Copy build artifacts from lorax punica kernels builder
|
||||
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
# Copy build artifacts from fbgemm builder
|
||||
COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.11/cmake-install /opt/conda/lib/python3.11/site-packages
|
||||
# Copy build artifacts from vllm builder
|
||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
# Copy build artifacts from mamba builder
|
||||
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=flashinfer-builder /opt/conda/lib/python3.10/site-packages/flashinfer/ /opt/conda/lib/python3.10/site-packages/flashinfer/
|
||||
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
|
||||
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /opt/conda/lib/python3.11/site-packages
|
||||
COPY --from=flashinfer-builder /opt/conda/lib/python3.11/site-packages/flashinfer/ /opt/conda/lib/python3.11/site-packages/flashinfer/
|
||||
|
||||
# Install flash-attention dependencies
|
||||
RUN pip install einops --no-cache-dir
|
||||
|
@ -257,7 +261,9 @@ RUN cd server && \
|
|||
pip install ".[bnb, accelerate, marlin, quantize, peft, outlines]" --no-cache-dir && \
|
||||
pip install nvidia-nccl-cu12==2.22.3
|
||||
|
||||
ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
|
||||
ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
|
||||
# Required to find libpython within the rust binaries
|
||||
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
|
||||
# This is needed because exl2 tries to load flash-attn
|
||||
# And fails with our builds.
|
||||
ENV EXLLAMA_NO_FLASH_ATTN=1
|
||||
|
|
|
@ -17,6 +17,8 @@ RUN cargo chef prepare --recipe-path recipe.json
|
|||
|
||||
FROM chef AS builder
|
||||
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
python3.11-dev
|
||||
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
|
||||
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
|
||||
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
|
||||
|
@ -64,14 +66,14 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
|||
hipsolver-dev \
|
||||
rccl-dev \
|
||||
cmake \
|
||||
python3-dev && \
|
||||
python3.11-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Keep in sync with `server/pyproject.toml`
|
||||
ARG MAMBA_VERSION=23.1.0-1
|
||||
ARG PYTORCH_VERSION='2.3.0'
|
||||
ARG ROCM_VERSION='6.0.2'
|
||||
ARG PYTHON_VERSION='3.10.10'
|
||||
ARG PYTHON_VERSION='3.11.10'
|
||||
# Automatically set by buildx
|
||||
ARG TARGETPLATFORM
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
|
@ -89,10 +91,18 @@ RUN chmod +x ~/mambaforge.sh && \
|
|||
mamba init && \
|
||||
rm ~/mambaforge.sh
|
||||
|
||||
# RUN conda install intel::mkl-static intel::mkl-include
|
||||
# Install pytorch
|
||||
# On arm64 we exit with an error code
|
||||
RUN case ${TARGETPLATFORM} in \
|
||||
"linux/arm64") exit 1 ;; \
|
||||
*) /opt/conda/bin/conda update -y conda && \
|
||||
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
|
||||
esac && \
|
||||
/opt/conda/bin/conda clean -ya
|
||||
# Install flash-attention, torch dependencies
|
||||
RUN pip install numpy einops ninja --no-cache-dir
|
||||
|
||||
RUN conda install intel::mkl-static intel::mkl-include
|
||||
RUN pip uninstall -y triton && \
|
||||
git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
|
||||
cd triton/python && \
|
||||
|
@ -172,19 +182,19 @@ ENV HF_HOME=/data \
|
|||
PORT=80
|
||||
|
||||
# Copy build artifacts from vllm builder
|
||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
|
||||
# Copy build artifacts from flash attention v2 builder
|
||||
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
|
||||
# Copy build artifacts from custom kernels builder
|
||||
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
|
||||
# Copy build artifacts from exllama kernels builder
|
||||
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
|
||||
# Copy build artifacts from exllamav2 kernels builder
|
||||
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
|
||||
|
||||
# Install server
|
||||
COPY proto proto
|
||||
|
@ -201,6 +211,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/l
|
|||
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
|
||||
# Install launcher
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
|
||||
|
||||
# AWS Sagemaker compatible image
|
||||
FROM base AS sagemaker
|
||||
|
|
|
@ -18,6 +18,8 @@ RUN cargo chef prepare --recipe-path recipe.json
|
|||
|
||||
FROM chef AS builder
|
||||
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
python3.11-dev
|
||||
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
|
||||
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
|
||||
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
|
||||
|
@ -42,9 +44,35 @@ RUN cargo build --profile release-opt
|
|||
|
||||
# Text Generation Inference base image for Intel
|
||||
|
||||
FROM intel/intel-extension-for-pytorch:2.1.30-xpu AS xpu
|
||||
FROM intel/intel-extension-for-pytorch:2.3.110-xpu AS xpu
|
||||
|
||||
USER root
|
||||
|
||||
ARG MAMBA_VERSION=23.1.0-1
|
||||
ARG PYTHON_VERSION='3.11.10'
|
||||
# Automatically set by buildx
|
||||
ARG TARGETPLATFORM
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
|
||||
# TGI seems to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
|
||||
# Install mamba
|
||||
# translating Docker's TARGETPLATFORM into mamba arches
|
||||
RUN case ${TARGETPLATFORM} in \
|
||||
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
|
||||
*) MAMBA_ARCH=x86_64 ;; \
|
||||
esac && \
|
||||
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
|
||||
RUN chmod +x ~/mambaforge.sh && \
|
||||
bash ~/mambaforge.sh -b -p /opt/conda && \
|
||||
rm ~/mambaforge.sh
|
||||
|
||||
RUN case ${TARGETPLATFORM} in \
|
||||
"linux/arm64") exit 1 ;; \
|
||||
*) /opt/conda/bin/conda update -y conda && \
|
||||
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
|
||||
esac && \
|
||||
/opt/conda/bin/conda clean -ya
|
||||
|
||||
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
|
||||
RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
|
||||
dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
|
||||
|
@ -54,7 +82,7 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea
|
|||
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
|
||||
| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
|
||||
|
||||
RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y intel-basekit xpu-smi cmake ninja-build pciutils
|
||||
|
||||
# Text Generation Inference base env
|
||||
ENV HF_HOME=/data \
|
||||
|
@ -63,9 +91,7 @@ ENV HF_HOME=/data \
|
|||
|
||||
|
||||
WORKDIR /usr/src
|
||||
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
|
||||
RUN pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed
|
||||
RUN pip install torch==2.3.1+cxx11.abi torchvision==0.18.1+cxx11.abi torchaudio==2.3.1+cxx11.abi intel-extension-for-pytorch==2.3.110+xpu oneccl_bind_pt==2.3.100+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --no-cache-dir
|
||||
|
||||
# Install server
|
||||
COPY proto proto
|
||||
|
@ -80,14 +106,12 @@ ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
|
|||
ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
|
||||
ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
|
||||
ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
|
||||
ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
|
||||
ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
||||
ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:/opt/conda/lib
|
||||
ENV PATH=/opt/conda/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
||||
ENV CCL_ZE_IPC_EXCHANGE=sockets
|
||||
ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
|
||||
ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
|
||||
|
||||
RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
|
||||
|
||||
# Install benchmarker
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
# Install router
|
||||
|
@ -123,7 +147,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
|
|||
PORT=80
|
||||
|
||||
ARG MAMBA_VERSION=23.1.0-1
|
||||
ARG PYTHON_VERSION='3.10.10'
|
||||
ARG PYTHON_VERSION='3.11.10'
|
||||
# Automatically set by buildx
|
||||
ARG TARGETPLATFORM
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
|
@ -140,12 +164,19 @@ RUN chmod +x ~/mambaforge.sh && \
|
|||
bash ~/mambaforge.sh -b -p /opt/conda && \
|
||||
rm ~/mambaforge.sh
|
||||
|
||||
RUN case ${TARGETPLATFORM} in \
|
||||
"linux/arm64") exit 1 ;; \
|
||||
*) /opt/conda/bin/conda update -y conda && \
|
||||
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
|
||||
esac && \
|
||||
/opt/conda/bin/conda clean -ya
|
||||
|
||||
RUN conda install -c conda-forge gperftools mkl
|
||||
|
||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev20240815%2Bcpu-cp310-cp310-linux_x86_64.whl
|
||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp310-cp310-linux_x86_64.whl
|
||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp310-cp310-linux_x86_64.whl
|
||||
RUN pip install triton numa
|
||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.5.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
|
||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
|
||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240815%2Bcpu-cp311-cp311-linux_x86_64.whl
|
||||
RUN pip install triton py-libnuma
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
|
@ -156,10 +187,11 @@ RUN cd torch-ccl && git submodule sync && git submodule update --init --recursiv
|
|||
|
||||
|
||||
ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
|
||||
ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
|
||||
ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
|
||||
ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
|
||||
ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
|
||||
ENV CCL_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
|
||||
ENV I_MPI_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
|
||||
ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
|
||||
ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/lib
|
||||
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"
|
||||
|
||||
# Install server
|
||||
COPY proto proto
|
||||
|
|
|
@ -376,10 +376,9 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
|
|||
// Send generation responses back to the infer task
|
||||
// If we receive an error from the Flume channel, it means that the client dropped the
|
||||
// request and we need to stop generating hence why we unwrap_or(true)
|
||||
let stopped = send_responses(generation, entry).map_err(|err| {
|
||||
let stopped = send_responses(generation, entry).inspect_err(|_err| {
|
||||
tracing::error!("Entry response channel error.");
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
err
|
||||
}).unwrap_or(true);
|
||||
if stopped {
|
||||
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||
|
|
|
@ -357,6 +357,7 @@ impl State {
|
|||
let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
|
||||
(block_allocation, &self.block_allocator)
|
||||
{
|
||||
tracing::debug!("Allocating {tokens} with {input_ids:?}");
|
||||
match block_allocator.allocate(tokens, input_ids).await {
|
||||
None => {
|
||||
// Entry is over budget
|
||||
|
|
|
@ -123,8 +123,6 @@ impl Allocator for RadixAllocator {
|
|||
prefill_tokens: prefill_tokens.clone(),
|
||||
};
|
||||
|
||||
tracing::debug!("Blocks {blocks:?}");
|
||||
|
||||
self.allocation_id += 1;
|
||||
self.allocations.insert(self.allocation_id, allocation);
|
||||
|
||||
|
|
58
flake.lock
58
flake.lock
|
@ -492,6 +492,24 @@
|
|||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-utils_7": {
|
||||
"inputs": {
|
||||
"systems": "systems_7"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1710146030,
|
||||
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"gitignore": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
|
@ -700,16 +718,16 @@
|
|||
},
|
||||
"nixpkgs_6": {
|
||||
"locked": {
|
||||
"lastModified": 1723912943,
|
||||
"narHash": "sha256-39F9GzyhxYcY3wTeKuEFWRJWcrGBosO4nf4xzMTWZX8=",
|
||||
"owner": "danieldk",
|
||||
"lastModified": 1724915739,
|
||||
"narHash": "sha256-7PgRge4mn5akFvhPwefuaLQGbF5BnmxlwZJEf7CgbrE=",
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "b82cdca86dbb30013b76c4b55d48806476820a5c",
|
||||
"rev": "85be051bb60943d3328d91aaf2598798f87e19af",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "danieldk",
|
||||
"ref": "cuda-12.4",
|
||||
"owner": "nixos",
|
||||
"ref": "nixos-unstable-small",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
|
@ -835,11 +853,11 @@
|
|||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1724638882,
|
||||
"narHash": "sha256-ap2jIQi/FuUHR6HCht6ASWhoz8EiB99XmI8Esot38VE=",
|
||||
"lastModified": 1726021481,
|
||||
"narHash": "sha256-4J4E+Fh+77XIYnq2RVtg+ENWXpu6t74P0jKN/f2RQmI=",
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"rev": "19b70f147b9c67a759e35824b241f1ed92e46694",
|
||||
"rev": "1c2c120246c51a644c20ba2a36a33d3bd4860d70",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
@ -938,17 +956,33 @@
|
|||
"type": "github"
|
||||
}
|
||||
},
|
||||
"systems_7": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"tgi-nix": {
|
||||
"inputs": {
|
||||
"flake-compat": "flake-compat_4",
|
||||
"flake-utils": "flake-utils_7",
|
||||
"nixpkgs": "nixpkgs_6"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1725011596,
|
||||
"narHash": "sha256-zfq8lOXFgJnKxxsqSelHuKUvhxgH3cEmLoAgsOO62Cg=",
|
||||
"lastModified": 1725950569,
|
||||
"narHash": "sha256-nJHA1SvIQbXySpL2ueNbzQOhnkQASa5tOLz/kdW0PWA=",
|
||||
"owner": "danieldk",
|
||||
"repo": "tgi-nix",
|
||||
"rev": "717c2b07e38538abf05237cca65b2d1363c2c9af",
|
||||
"rev": "d40f3c22e9bcc5e16c94d4605cf6a7d74dd07f46",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
|
48
flake.nix
48
flake.nix
|
@ -46,12 +46,30 @@
|
|||
launcher = cargoNix.workspaceMembers.text-generation-launcher.build.override {
|
||||
inherit crateOverrides;
|
||||
};
|
||||
router = cargoNix.workspaceMembers.text-generation-router-v3.build.override {
|
||||
inherit crateOverrides;
|
||||
};
|
||||
router =
|
||||
let
|
||||
routerUnwrapped = cargoNix.workspaceMembers.text-generation-router-v3.build.override {
|
||||
inherit crateOverrides;
|
||||
};
|
||||
packagePath =
|
||||
with pkgs.python3.pkgs;
|
||||
makePythonPath [
|
||||
protobuf
|
||||
sentencepiece
|
||||
torch
|
||||
transformers
|
||||
];
|
||||
in
|
||||
pkgs.writeShellApplication {
|
||||
name = "text-generation-router";
|
||||
text = ''
|
||||
PYTHONPATH="${packagePath}" ${routerUnwrapped}/bin/text-generation-router "$@"
|
||||
'';
|
||||
};
|
||||
server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
|
||||
in
|
||||
{
|
||||
formatter = pkgs.nixfmt-rfc-style;
|
||||
devShells = with pkgs; rec {
|
||||
default = pure;
|
||||
|
||||
|
@ -63,6 +81,29 @@
|
|||
server
|
||||
];
|
||||
};
|
||||
test = mkShell {
|
||||
buildInputs =
|
||||
[
|
||||
# benchmark
|
||||
# launcher
|
||||
# router
|
||||
server
|
||||
openssl.dev
|
||||
pkg-config
|
||||
cargo
|
||||
rustfmt
|
||||
clippy
|
||||
]
|
||||
++ (with python3.pkgs; [
|
||||
docker
|
||||
pytest
|
||||
pytest-asyncio
|
||||
syrupy
|
||||
pre-commit
|
||||
ruff
|
||||
]);
|
||||
|
||||
};
|
||||
|
||||
impure = mkShell {
|
||||
buildInputs =
|
||||
|
@ -82,6 +123,7 @@
|
|||
docker
|
||||
pip
|
||||
ipdb
|
||||
click
|
||||
pyright
|
||||
pytest
|
||||
pytest-asyncio
|
||||
|
|
|
@ -19,6 +19,7 @@ from syrupy.extensions.json import JSONSnapshotExtension
|
|||
from text_generation import AsyncClient
|
||||
from text_generation.types import (
|
||||
BestOfSequence,
|
||||
Message,
|
||||
ChatComplete,
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionComplete,
|
||||
|
@ -97,25 +98,25 @@ class ResponseComparator(JSONSnapshotExtension):
|
|||
) -> bool:
|
||||
def convert_data(data):
|
||||
data = json.loads(data)
|
||||
if isinstance(data, Dict) and "choices" in data:
|
||||
choices = data["choices"]
|
||||
if isinstance(choices, List) and len(choices) >= 1:
|
||||
if "delta" in choices[0]:
|
||||
return ChatCompletionChunk(**data)
|
||||
if "text" in choices[0]:
|
||||
return Completion(**data)
|
||||
return ChatComplete(**data)
|
||||
return _convert_data(data)
|
||||
|
||||
def _convert_data(data):
|
||||
if isinstance(data, Dict):
|
||||
return Response(**data)
|
||||
if "choices" in data:
|
||||
data["choices"] = list(
|
||||
sorted(data["choices"], key=lambda x: x["index"])
|
||||
)
|
||||
choices = data["choices"]
|
||||
if isinstance(choices, List) and len(choices) >= 1:
|
||||
if "delta" in choices[0]:
|
||||
return ChatCompletionChunk(**data)
|
||||
if "text" in choices[0]:
|
||||
return Completion(**data)
|
||||
return ChatComplete(**data)
|
||||
else:
|
||||
return Response(**data)
|
||||
if isinstance(data, List):
|
||||
if (
|
||||
len(data) > 0
|
||||
and "object" in data[0]
|
||||
and data[0]["object"] == "text_completion"
|
||||
):
|
||||
return [Completion(**d) for d in data]
|
||||
return [Response(**d) for d in data]
|
||||
return [_convert_data(d) for d in data]
|
||||
raise NotImplementedError
|
||||
|
||||
def eq_token(token: Token, other: Token) -> bool:
|
||||
|
@ -571,3 +572,38 @@ def generate_load():
|
|||
return await asyncio.gather(*futures)
|
||||
|
||||
return generate_load_inner
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def generate_multi():
    """Provide a coroutine that sends several chat prompts concurrently.

    The returned coroutine submits the prompts in a random order and then
    puts the replies back into the order of the ``prompts`` argument, so
    ``result[i]`` always belongs to ``prompts[i]``.
    """

    async def generate_multi_inner(
        client: AsyncClient,
        prompts: List[str],
        max_new_tokens: int,
        seed: Optional[int] = None,
    ) -> List[Response]:
        # NOTE(review): annotated as ``List[Response]`` but the values are
        # whatever ``client.chat`` returns -- confirm the actual type.
        import numpy as np

        # order[slot] is the index in ``prompts`` of the request sent in
        # position ``slot``.
        order = np.random.permutation(np.arange(len(prompts)))

        pending = []
        for source_index in order:
            pending.append(
                client.chat(
                    messages=[Message(role="user", content=prompts[source_index])],
                    max_tokens=max_new_tokens,
                    temperature=0,
                    seed=seed,
                )
            )

        shuffled_responses = await asyncio.gather(*pending)

        # Undo the shuffle: the reply in position ``slot`` answers the prompt
        # at ``order[slot]``.
        ordered = [None] * len(shuffled_responses)
        for slot, source_index in enumerate(order):
            ordered[source_index] = shuffled_responses[slot]
        return ordered

    return generate_multi_inner
|
||||
|
|
|
@ -1,38 +1,38 @@
|
|||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": " A Beginner’s Guide\nDeep learning is a subset"
|
||||
},
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": " PR for more information?"
|
||||
"text": " This is a question that has puzzled many people for"
|
||||
},
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "hd20220811-"
|
||||
},
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": "le Business Incubator is providing a workspace"
|
||||
"text": "usculas_minusculas(s):\n \"\"\"\n"
|
||||
},
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": " severely flawed and often has a substandard"
|
||||
"text": " Paris\nWhat is the capital of France?\nThe"
|
||||
}
|
||||
],
|
||||
"created": 1722014725,
|
||||
"created": 1725877154,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 36,
|
||||
"prompt_tokens": 8,
|
||||
"total_tokens": 44
|
||||
"completion_tokens": 40,
|
||||
"prompt_tokens": 22,
|
||||
"total_tokens": 62
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,12 +5,12 @@
|
|||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": "\n"
|
||||
"text": " A"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -20,12 +20,72 @@
|
|||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": "\n"
|
||||
"text": " This"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": " Paris"
|
||||
}
|
||||
],
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "us"
|
||||
}
|
||||
],
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": " Beginner"
|
||||
}
|
||||
],
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": " is"
|
||||
}
|
||||
],
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -38,9 +98,9 @@
|
|||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -50,12 +110,12 @@
|
|||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "hd"
|
||||
"text": "cul"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -65,12 +125,12 @@
|
|||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": "\n"
|
||||
"text": "’s"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -80,12 +140,12 @@
|
|||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": "\n"
|
||||
"text": " a"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -95,12 +155,12 @@
|
|||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": "\n"
|
||||
"text": "What"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -110,12 +170,12 @@
|
|||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "aho"
|
||||
"text": "as"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -125,12 +185,12 @@
|
|||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": "2"
|
||||
"text": " Guide"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -140,252 +200,12 @@
|
|||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": "2"
|
||||
"text": " question"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": "2"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "ima"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": "."
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": "."
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": "."
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": " Sarah"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": " Yes"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": " And"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "i"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": "'"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": ","
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": " what"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "'"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": "s"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": " Moh"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -398,9 +218,9 @@
|
|||
"text": " is"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -410,12 +230,12 @@
|
|||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "m"
|
||||
"text": "_minus"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -425,12 +245,12 @@
|
|||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": " Room"
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -440,12 +260,12 @@
|
|||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": "s"
|
||||
"text": " that"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -458,9 +278,9 @@
|
|||
"text": " the"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -470,12 +290,12 @@
|
|||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": " tired"
|
||||
"text": "cul"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -485,12 +305,12 @@
|
|||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": ":"
|
||||
"text": "Deep"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -500,12 +320,12 @@
|
|||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": "'"
|
||||
"text": " has"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -518,9 +338,9 @@
|
|||
"text": " capital"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -530,12 +350,192 @@
|
|||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": ","
|
||||
"text": "as"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": " learning"
|
||||
}
|
||||
],
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": " puzzled"
|
||||
}
|
||||
],
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": " of"
|
||||
}
|
||||
],
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "(s"
|
||||
}
|
||||
],
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": " is"
|
||||
}
|
||||
],
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": " many"
|
||||
}
|
||||
],
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": " France"
|
||||
}
|
||||
],
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": "):\n"
|
||||
}
|
||||
],
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": " a"
|
||||
}
|
||||
],
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": " people"
|
||||
}
|
||||
],
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": "?\n"
|
||||
}
|
||||
],
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": " "
|
||||
}
|
||||
],
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -545,12 +545,12 @@
|
|||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": " She"
|
||||
"text": " subset"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -560,12 +560,12 @@
|
|||
"finish_reason": "length",
|
||||
"index": 1,
|
||||
"logprobs": null,
|
||||
"text": " scale"
|
||||
"text": " for"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -575,12 +575,12 @@
|
|||
"finish_reason": "length",
|
||||
"index": 2,
|
||||
"logprobs": null,
|
||||
"text": " of"
|
||||
"text": "The"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
},
|
||||
|
@ -590,12 +590,12 @@
|
|||
"finish_reason": "length",
|
||||
"index": 3,
|
||||
"logprobs": null,
|
||||
"text": " its"
|
||||
"text": " \"\"\"\n"
|
||||
}
|
||||
],
|
||||
"created": 1724833943,
|
||||
"created": 1725883643,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native"
|
||||
}
|
||||
|
|
|
@ -4,17 +4,17 @@
|
|||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"text": " PR for flake8"
|
||||
"text": " A Beginner’s Guide\nDeep learning is a subset"
|
||||
}
|
||||
],
|
||||
"created": 1713284454,
|
||||
"created": 1725876621,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "2.0.1-native",
|
||||
"system_fingerprint": "2.2.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 5,
|
||||
"completion_tokens": 10,
|
||||
"prompt_tokens": 6,
|
||||
"total_tokens": 11
|
||||
"total_tokens": 16
|
||||
}
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -11,7 +11,7 @@ from text_generation.types import (
|
|||
@pytest.fixture(scope="module")
|
||||
def flash_llama_completion_handle(launcher):
|
||||
with launcher(
|
||||
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
@ -34,16 +34,19 @@ def test_flash_llama_completion_single_prompt(
|
|||
f"{flash_llama_completion.base_url}/v1/completions",
|
||||
json={
|
||||
"model": "tgi",
|
||||
"prompt": "Say this is a test",
|
||||
"max_tokens": 5,
|
||||
"seed": 0,
|
||||
"prompt": "What is Deep Learning?",
|
||||
"max_tokens": 10,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
headers=flash_llama_completion.headers,
|
||||
stream=False,
|
||||
)
|
||||
response = response.json()
|
||||
assert len(response["choices"]) == 1
|
||||
|
||||
assert (
|
||||
response["choices"][0]["text"]
|
||||
== " A Beginner’s Guide\nDeep learning is a subset"
|
||||
)
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
|
@ -53,9 +56,15 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
|
|||
f"{flash_llama_completion.base_url}/v1/completions",
|
||||
json={
|
||||
"model": "tgi",
|
||||
"prompt": ["Say", "this", "is", "a"],
|
||||
"prompt": [
|
||||
"What is Deep Learning?",
|
||||
"Is water wet?",
|
||||
"What is the capital of France?",
|
||||
"def mai",
|
||||
],
|
||||
"max_tokens": 10,
|
||||
"seed": 0,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
headers=flash_llama_completion.headers,
|
||||
stream=False,
|
||||
|
@ -63,9 +72,16 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
|
|||
response = response.json()
|
||||
assert len(response["choices"]) == 4
|
||||
|
||||
all_indexes = [choice["index"] for choice in response["choices"]]
|
||||
all_indexes = [(choice["index"], choice["text"]) for choice in response["choices"]]
|
||||
all_indexes.sort()
|
||||
assert all_indexes == [0, 1, 2, 3]
|
||||
all_indices, all_strings = zip(*all_indexes)
|
||||
assert list(all_indices) == [0, 1, 2, 3]
|
||||
assert list(all_strings) == [
|
||||
" A Beginner’s Guide\nDeep learning is a subset",
|
||||
" This is a question that has puzzled many people for",
|
||||
" Paris\nWhat is the capital of France?\nThe",
|
||||
'usculas_minusculas(s):\n """\n',
|
||||
]
|
||||
|
||||
assert response == response_snapshot
|
||||
|
||||
|
@ -77,19 +93,21 @@ async def test_flash_llama_completion_many_prompts_stream(
|
|||
request = {
|
||||
"model": "tgi",
|
||||
"prompt": [
|
||||
"What color is the sky?",
|
||||
"What is Deep Learning?",
|
||||
"Is water wet?",
|
||||
"What is the capital of France?",
|
||||
"def mai",
|
||||
],
|
||||
"max_tokens": 10,
|
||||
"seed": 0,
|
||||
"temperature": 0.0,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
url = f"{flash_llama_completion.base_url}/v1/completions"
|
||||
|
||||
chunks = []
|
||||
strings = [""] * 4
|
||||
async with ClientSession(headers=flash_llama_completion.headers) as session:
|
||||
async with session.post(url, json=request) as response:
|
||||
# iterate over the stream
|
||||
|
@ -108,7 +126,15 @@ async def test_flash_llama_completion_many_prompts_stream(
|
|||
for c in chunk:
|
||||
chunks.append(Completion(**c))
|
||||
assert "choices" in c
|
||||
assert 0 <= c["choices"][0]["index"] <= 4
|
||||
index = c["choices"][0]["index"]
|
||||
assert 0 <= index <= 4
|
||||
strings[index] += c["choices"][0]["text"]
|
||||
|
||||
assert response.status == 200
|
||||
assert list(strings) == [
|
||||
" A Beginner’s Guide\nDeep learning is a subset",
|
||||
" This is a question that has puzzled many people for",
|
||||
" Paris\nWhat is the capital of France?\nThe",
|
||||
'usculas_minusculas(s):\n """\n',
|
||||
]
|
||||
assert chunks == response_snapshot
|
||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
|
@ -6,9 +6,10 @@ authors = ["Nicolas Patry <nicolas@huggingface.co>"]
|
|||
|
||||
[tool.poetry.dependencies]
|
||||
pydantic = "> 2, < 3"
|
||||
python = ">=3.9,<3.13"
|
||||
python = ">=3.10,<3.13"
|
||||
syrupy = "^4.7.1"
|
||||
text-generation = "^0.6.0"
|
||||
pytest = "^7.4.0"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
docker = "^6.1.3"
|
||||
docker = "^7"
|
||||
numpy = "^1.20"
|
||||
|
|
|
@ -1,34 +1,35 @@
|
|||
aiohttp==3.8.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
annotated-types==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
docker==6.1.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
exceptiongroup==1.1.3 ; python_version >= "3.9" and python_version < "3.11"
|
||||
filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
frozenlist==1.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
iniconfig==2.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
multidict==6.0.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pluggy==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pydantic-core==2.16.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pydantic==2.6.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pytest-asyncio==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pytest==7.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pywin32==306 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
syrupy==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
text-generation==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11"
|
||||
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
websocket-client==1.6.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
yarl==1.9.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
aiohappyeyeballs==2.4.0 ; python_version >= "3.10" and python_version < "3.13"
|
||||
aiohttp==3.10.5 ; python_version >= "3.10" and python_version < "3.13"
|
||||
aiosignal==1.3.1 ; python_version >= "3.10" and python_version < "3.13"
|
||||
annotated-types==0.7.0 ; python_version >= "3.10" and python_version < "3.13"
|
||||
async-timeout==4.0.3 ; python_version >= "3.10" and python_version < "3.11"
|
||||
attrs==24.2.0 ; python_version >= "3.10" and python_version < "3.13"
|
||||
certifi==2024.8.30 ; python_version >= "3.10" and python_version < "3.13"
|
||||
charset-normalizer==3.3.2 ; python_version >= "3.10" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.10" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
docker==7.1.0 ; python_version >= "3.10" and python_version < "3.13"
|
||||
exceptiongroup==1.2.2 ; python_version >= "3.10" and python_version < "3.11"
|
||||
filelock==3.16.0 ; python_version >= "3.10" and python_version < "3.13"
|
||||
frozenlist==1.4.1 ; python_version >= "3.10" and python_version < "3.13"
|
||||
fsspec==2024.9.0 ; python_version >= "3.10" and python_version < "3.13"
|
||||
huggingface-hub==0.24.6 ; python_version >= "3.10" and python_version < "3.13"
|
||||
idna==3.8 ; python_version >= "3.10" and python_version < "3.13"
|
||||
iniconfig==2.0.0 ; python_version >= "3.10" and python_version < "3.13"
|
||||
multidict==6.1.0 ; python_version >= "3.10" and python_version < "3.13"
|
||||
numpy==1.26.4 ; python_version >= "3.10" and python_version < "3.13"
|
||||
packaging==24.1 ; python_version >= "3.10" and python_version < "3.13"
|
||||
pluggy==1.5.0 ; python_version >= "3.10" and python_version < "3.13"
|
||||
pydantic-core==2.23.3 ; python_version >= "3.10" and python_version < "3.13"
|
||||
pydantic==2.9.1 ; python_version >= "3.10" and python_version < "3.13"
|
||||
pytest-asyncio==0.21.2 ; python_version >= "3.10" and python_version < "3.13"
|
||||
pytest==7.4.4 ; python_version >= "3.10" and python_version < "3.13"
|
||||
pywin32==306 ; python_version >= "3.10" and python_version < "3.13" and sys_platform == "win32"
|
||||
pyyaml==6.0.2 ; python_version >= "3.10" and python_version < "3.13"
|
||||
requests==2.32.3 ; python_version >= "3.10" and python_version < "3.13"
|
||||
syrupy==4.7.1 ; python_version >= "3.10" and python_version < "3.13"
|
||||
text-generation==0.6.1 ; python_version >= "3.10" and python_version < "3.13"
|
||||
tomli==2.0.1 ; python_version >= "3.10" and python_version < "3.11"
|
||||
tqdm==4.66.5 ; python_version >= "3.10" and python_version < "3.13"
|
||||
typing-extensions==4.12.2 ; python_version >= "3.10" and python_version < "3.13"
|
||||
urllib3==2.2.2 ; python_version >= "3.10" and python_version < "3.13"
|
||||
yarl==1.11.1 ; python_version >= "3.10" and python_version < "3.13"
|
||||
|
|
|
@ -1843,9 +1843,8 @@ fn main() -> Result<(), LauncherError> {
|
|||
shutdown.clone(),
|
||||
&shutdown_receiver,
|
||||
)
|
||||
.map_err(|err| {
|
||||
.inspect_err(|_| {
|
||||
shutdown_shards(shutdown.clone(), &shutdown_receiver);
|
||||
err
|
||||
})?;
|
||||
|
||||
// Default exit code
|
||||
|
|
|
@ -28,6 +28,9 @@ defaultCrateOverrides
|
|||
];
|
||||
};
|
||||
};
|
||||
pyo3-build-config = attrs: {
|
||||
buildInputs = [ python3 ];
|
||||
};
|
||||
text-generation-benchmark = attrs: {
|
||||
src = filter {
|
||||
root = ../benchmark;
|
||||
|
|
|
@ -61,6 +61,7 @@ uuid = { version = "1.9.1", default-features = false, features = [
|
|||
] }
|
||||
csv = "1.3.0"
|
||||
ureq = "=2.9"
|
||||
pyo3 = { version = "0.22.2", features = ["auto-initialize"] }
|
||||
|
||||
|
||||
[build-dependencies]
|
||||
|
|
|
@ -336,6 +336,8 @@ pub enum InferError {
|
|||
ValidationError(#[from] ValidationError),
|
||||
#[error("Incomplete generation")]
|
||||
IncompleteGeneration,
|
||||
#[error("Incomplete generation stream")]
|
||||
IncompleteGenerationStream,
|
||||
#[error("Template error: {0}")]
|
||||
TemplateError(#[from] minijinja::Error),
|
||||
#[error("Missing template vatiable: {0}")]
|
||||
|
@ -351,6 +353,7 @@ impl InferError {
|
|||
InferError::Overloaded(_) => "overloaded",
|
||||
InferError::ValidationError(_) => "validation",
|
||||
InferError::IncompleteGeneration => "incomplete_generation",
|
||||
InferError::IncompleteGenerationStream => "incomplete_generation_stream",
|
||||
InferError::TemplateError(_) => "template_error",
|
||||
InferError::MissingTemplateVariable(_) => "missing_template_variable",
|
||||
InferError::ToolError(_) => "tool_error",
|
||||
|
|
|
@ -41,6 +41,7 @@ use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
|
|||
use hf_hub::{Cache, Repo, RepoType};
|
||||
use http::header::AUTHORIZATION;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use pyo3::types::IntoPyDict;
|
||||
use serde_json::Value;
|
||||
use std::convert::Infallible;
|
||||
use std::fs::File;
|
||||
|
@ -48,7 +49,6 @@ use std::io::BufReader;
|
|||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::path::{Path, PathBuf};
|
||||
use thiserror::Error;
|
||||
use tokenizers::processors::template::TemplateProcessing;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::select;
|
||||
use tokio::signal;
|
||||
|
@ -318,7 +318,10 @@ pub(crate) async fn generate_internal(
|
|||
metrics::counter!("tgi_request_count").increment(1);
|
||||
|
||||
// Do not log ultra long inputs, like image payloads.
|
||||
tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);
|
||||
tracing::debug!(
|
||||
"Input: {}",
|
||||
&req.inputs.chars().take(1000).collect::<String>()
|
||||
);
|
||||
|
||||
let compute_characters = req.inputs.chars().count();
|
||||
let mut add_prompt = None;
|
||||
|
@ -674,7 +677,7 @@ async fn generate_stream_internal(
|
|||
// Check if generation reached the end
|
||||
// Skip if we already sent an error
|
||||
if !end_reached && !error {
|
||||
let err = InferError::IncompleteGeneration;
|
||||
let err = InferError::IncompleteGenerationStream;
|
||||
metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
|
||||
tracing::error!("{err}");
|
||||
yield Ok(Event::from(err));
|
||||
|
@ -1857,18 +1860,34 @@ pub async fn run(
|
|||
});
|
||||
|
||||
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
|
||||
let mut tokenizer = Tokenizer::from_file(filename).ok();
|
||||
if let Some(tokenizer) = &mut tokenizer {
|
||||
if let Some(class) = &tokenizer_config.tokenizer_class {
|
||||
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
|
||||
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
|
||||
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
|
||||
tokenizer.with_post_processor(post_processor);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tokenizer
|
||||
use pyo3::prelude::*;
|
||||
let convert = pyo3::Python::with_gil(|py| -> PyResult<()> {
|
||||
let transformers = py.import_bound("transformers")?;
|
||||
let auto = transformers.getattr("AutoTokenizer")?;
|
||||
let from_pretrained = auto.getattr("from_pretrained")?;
|
||||
let args = (tokenizer_name.to_string(),);
|
||||
let kwargs = [(
|
||||
"revision",
|
||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
||||
)]
|
||||
.into_py_dict_bound(py);
|
||||
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
|
||||
let save = tokenizer.getattr("save_pretrained")?;
|
||||
let args = ("out".to_string(),);
|
||||
save.call1(args)?;
|
||||
Ok(())
|
||||
})
|
||||
.inspect_err(|err| {
|
||||
tracing::error!("Failed to import python tokenizer {err}");
|
||||
});
|
||||
let filename = if convert.is_ok() {
|
||||
// If we have correctly loaded and resaved with transformers
|
||||
// We might have modified the tokenizer.json according to transformers
|
||||
"out/tokenizer.json".into()
|
||||
} else {
|
||||
filename
|
||||
};
|
||||
Tokenizer::from_file(filename).ok()
|
||||
});
|
||||
|
||||
let config: Option<Config> = config_filename.and_then(|filename| {
|
||||
|
@ -2555,6 +2574,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
|||
InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
|
||||
InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
InferError::IncompleteGenerationStream => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
|
@ -2587,77 +2607,6 @@ pub enum WebServerError {
|
|||
Axum(#[from] axum::BoxError),
|
||||
}
|
||||
|
||||
/// Create a post_processor for the LlamaTokenizer
|
||||
fn create_post_processor(
|
||||
tokenizer: &Tokenizer,
|
||||
tokenizer_config: &HubTokenizerConfig,
|
||||
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
|
||||
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
|
||||
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
|
||||
|
||||
let bos_token = tokenizer_config.bos_token.as_ref();
|
||||
let eos_token = tokenizer_config.eos_token.as_ref();
|
||||
|
||||
if add_bos_token && bos_token.is_none() {
|
||||
panic!("add_bos_token = true but bos_token is None");
|
||||
}
|
||||
|
||||
if add_eos_token && eos_token.is_none() {
|
||||
panic!("add_eos_token = true but eos_token is None");
|
||||
}
|
||||
|
||||
let mut single = Vec::new();
|
||||
let mut pair = Vec::new();
|
||||
let mut special_tokens = Vec::new();
|
||||
|
||||
if add_bos_token {
|
||||
if let Some(bos) = bos_token {
|
||||
let bos_token_id = tokenizer
|
||||
.token_to_id(bos.as_str())
|
||||
.expect("Should have found the bos token id");
|
||||
special_tokens.push((bos.as_str(), bos_token_id));
|
||||
single.push(format!("{}:0", bos.as_str()));
|
||||
pair.push(format!("{}:0", bos.as_str()));
|
||||
}
|
||||
}
|
||||
|
||||
single.push("$A:0".to_string());
|
||||
pair.push("$A:0".to_string());
|
||||
|
||||
if add_eos_token {
|
||||
if let Some(eos) = eos_token {
|
||||
let eos_token_id = tokenizer
|
||||
.token_to_id(eos.as_str())
|
||||
.expect("Should have found the eos token id");
|
||||
special_tokens.push((eos.as_str(), eos_token_id));
|
||||
single.push(format!("{}:0", eos.as_str()));
|
||||
pair.push(format!("{}:0", eos.as_str()));
|
||||
}
|
||||
}
|
||||
|
||||
if add_bos_token {
|
||||
if let Some(bos) = bos_token {
|
||||
pair.push(format!("{}:1", bos.as_str()));
|
||||
}
|
||||
}
|
||||
|
||||
pair.push("$B:1".to_string());
|
||||
|
||||
if add_eos_token {
|
||||
if let Some(eos) = eos_token {
|
||||
pair.push(format!("{}:1", eos.as_str()));
|
||||
}
|
||||
}
|
||||
|
||||
let post_processor = TemplateProcessing::builder()
|
||||
.try_single(single)?
|
||||
.try_pair(pair)?
|
||||
.special_tokens(special_tokens)
|
||||
.build()?;
|
||||
|
||||
Ok(post_processor)
|
||||
}
|
||||
|
||||
type PreparedInput = (String, Option<GrammarType>, bool);
|
||||
|
||||
fn prepare_chat_input(
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
install-flashinfer:
|
||||
pip install flashinfer==0.1.5 -i https://flashinfer.ai/whl/cu124/torch2.4
|
||||
pip install flashinfer==0.1.6 -i https://flashinfer.ai/whl/cu124/torch2.4
|
||||
|
|
|
@ -267,7 +267,7 @@ def test_batch_concatenate(
|
|||
assert next_batch.max_input_length == 3
|
||||
|
||||
assert next_batch.requests[0] == next_batch_0.requests[0]
|
||||
assert next_batch.requests[1:] == next_batch_1.requests
|
||||
assert next_batch.requests[1:] == list(next_batch_1.requests)
|
||||
|
||||
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
|
||||
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
|
||||
|
|
|
@ -262,7 +262,7 @@ def test_batch_concatenate(
|
|||
assert next_batch.max_input_length == 3
|
||||
|
||||
assert next_batch.requests[0] == next_batch_0.requests[0]
|
||||
assert next_batch.requests[1:] == next_batch_1.requests
|
||||
assert next_batch.requests[1:] == list(next_batch_1.requests)
|
||||
|
||||
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
|
||||
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
|
||||
|
|
|
@ -281,7 +281,7 @@ def test_batch_concatenate(
|
|||
assert next_batch.max_decoder_input_length == 3
|
||||
|
||||
assert next_batch.requests[0] == next_batch_0.requests[0]
|
||||
assert next_batch.requests[1:] == next_batch_1.requests
|
||||
assert next_batch.requests[1:] == list(next_batch_1.requests)
|
||||
|
||||
assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
|
||||
assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers
|
||||
|
|
|
@ -22,9 +22,9 @@ def attention(
|
|||
|
||||
# We do not need to check window_size_left (not supported) here, because it is already checked ahead of time at model load.
|
||||
ipex.llm.functional.varlen_attention(
|
||||
q,
|
||||
key_cache,
|
||||
value_cache,
|
||||
q.contiguous() if q.device.type == "xpu" else q,
|
||||
key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
|
||||
value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
|
|
|
@ -82,7 +82,7 @@ def init_cpu_threads_env(rank_id: int, world_size: int):
|
|||
import numa
|
||||
import psutil
|
||||
|
||||
nodes = numa.get_max_node() + 1
|
||||
nodes = numa.info.get_max_node() + 1
|
||||
rank_per_node = math.ceil(world_size / nodes)
|
||||
num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes)
|
||||
node_id = int(rank_id / rank_per_node)
|
||||
|
@ -91,18 +91,22 @@ def init_cpu_threads_env(rank_id: int, world_size: int):
|
|||
num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1)
|
||||
else:
|
||||
num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS"))
|
||||
if len(numa.get_membind()) == nodes:
|
||||
numa.set_membind([node_id])
|
||||
if len(numa.memory.get_membind_nodes()) == nodes:
|
||||
numa.memory.set_membind_nodes((node_id))
|
||||
torch.set_num_threads(num_cpus_per_rank)
|
||||
if len(numa.get_affinity(0)) == psutil.cpu_count(logical=True):
|
||||
if len(numa.schedule.get_affinitive_cpus(0)) == psutil.cpu_count(logical=True):
|
||||
cpu_start = num_cpus_per_rank * rank_offset_per_node
|
||||
numa.set_affinity(
|
||||
numa.schedule.run_on_cpus(
|
||||
0,
|
||||
list(numa.node_to_cpus(node_id))[
|
||||
cpu_start : cpu_start + num_cpus_per_rank
|
||||
],
|
||||
*(
|
||||
numa.info.node_to_cpus(node_id)[
|
||||
cpu_start : cpu_start + num_cpus_per_rank
|
||||
]
|
||||
),
|
||||
)
|
||||
logger.info(f"affinity={numa.get_affinity(0)}, membind = {numa.get_membind()}")
|
||||
logger.info(
|
||||
f"affinity={numa.schedule.get_affinitive_cpus(0)}, membind = {numa.memory.get_membind_nodes()}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -272,6 +276,8 @@ class FlashCausalLMBatch(Batch):
|
|||
assert prefix_len > 0
|
||||
prefix_len -= 1
|
||||
|
||||
# Commented out as it's costly.
|
||||
# log_master(logger.debug, "Tokenized input ids {tokenized_input}")
|
||||
prefix_ids.append(tokenized_input[:prefix_len])
|
||||
tokenized_input = tokenized_input[prefix_len:]
|
||||
|
||||
|
@ -515,6 +521,7 @@ class FlashCausalLMBatch(Batch):
|
|||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
assert len(pb.requests) > 0
|
||||
batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
|
||||
return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
||||
|
||||
|
@ -640,6 +647,7 @@ class FlashCausalLMBatch(Batch):
|
|||
adapter_segments = torch.tensor(
|
||||
adapter_segments, dtype=torch.int32, device=device
|
||||
)
|
||||
# assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
|
||||
|
||||
return type(self)(
|
||||
batch_id=self.batch_id,
|
||||
|
@ -834,6 +842,8 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
start_slots = torch.concat(start_slots)
|
||||
|
||||
# assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
|
||||
|
||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||
next_token_chooser_parameters,
|
||||
dtype=batches[0].next_token_chooser.dtype,
|
||||
|
@ -1152,27 +1162,6 @@ class FlashCausalLM(Model):
|
|||
input_lengths=input_lengths,
|
||||
prefix_lens=prefix_lengths,
|
||||
)
|
||||
|
||||
self.cuda_graphs[bs] = {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"kv_cache": self.kv_cache,
|
||||
"block_tables": block_tables,
|
||||
"slots": slots,
|
||||
"input_lengths": input_lengths_tensor,
|
||||
"prefix_lengths": prefix_lengths_tensor,
|
||||
}
|
||||
seqlen = Seqlen(
|
||||
input_lengths=input_lengths_tensor,
|
||||
prefix_lengths=prefix_lengths_tensor,
|
||||
cu_seqlen_q=None,
|
||||
max_q=1,
|
||||
max_k=max_s,
|
||||
)
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
self.cuda_graphs[bs]["graph"] = graph
|
||||
|
||||
if ATTENTION == "flashinfer":
|
||||
from text_generation_server.layers.attention.flashinfer import (
|
||||
create_decode_state_cuda_graphs,
|
||||
)
|
||||
|
@ -1189,21 +1178,38 @@ class FlashCausalLM(Model):
|
|||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
)
|
||||
self.cuda_graphs[bs]["state"] = state
|
||||
else:
|
||||
state = None
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
self.cuda_graphs[bs] = {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"kv_cache": self.kv_cache,
|
||||
"block_tables": block_tables,
|
||||
"slots": slots,
|
||||
"input_lengths": input_lengths_tensor,
|
||||
"prefix_lengths": prefix_lengths_tensor,
|
||||
"state": state,
|
||||
"graph": graph,
|
||||
}
|
||||
|
||||
torch.cuda.synchronize()
|
||||
# Run once outside to warmup
|
||||
with self._forward_context(
|
||||
block_tables=block_tables,
|
||||
cu_seqlen_prefill=None,
|
||||
input_lengths=input_lengths,
|
||||
input_lengths_tensor=input_lengths_tensor,
|
||||
state=state,
|
||||
prefix_lens=prefix_lengths,
|
||||
prefix_lens_tensor=prefix_lengths_tensor,
|
||||
):
|
||||
seqlen = Seqlen(
|
||||
input_lengths=input_lengths_tensor,
|
||||
prefix_lengths=prefix_lengths_tensor,
|
||||
cu_seqlen_q=None,
|
||||
max_q=1,
|
||||
max_k=max_s,
|
||||
)
|
||||
self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
|
@ -1216,6 +1222,7 @@ class FlashCausalLM(Model):
|
|||
prefill_cache_indices=None,
|
||||
lm_head_indices=None,
|
||||
)
|
||||
del seqlen
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
@ -1481,9 +1488,7 @@ class FlashCausalLM(Model):
|
|||
with self._forward_context(
|
||||
block_tables=block_tables,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
input_lengths=batch.input_lengths,
|
||||
input_lengths_tensor=input_lengths + prefix_lens_tensor,
|
||||
prefix_lens=batch.prefix_lens,
|
||||
input_lengths_tensor=input_lengths,
|
||||
prefix_lens_tensor=prefix_lens_tensor,
|
||||
):
|
||||
max_k = (input_lengths + prefix_lens_tensor).max().item()
|
||||
|
@ -1521,26 +1526,28 @@ class FlashCausalLM(Model):
|
|||
input_lengths=batch.input_lengths,
|
||||
prefix_lens=batch.prefix_lens,
|
||||
)
|
||||
# assert block_tables.shape[0] >= slots.shape[0]
|
||||
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
|
||||
else:
|
||||
cuda_graph["block_tables"][
|
||||
: block_tables.shape[0], : block_tables.shape[1]
|
||||
] = block_tables
|
||||
cuda_graph["slots"].fill_(-1)
|
||||
|
||||
# XXX: This is working only because block 0 is reserved for the healthcheck
|
||||
# so it doesn't matter if we override it with bogus values.
|
||||
cuda_graph["slots"].fill_(0)
|
||||
cuda_graph["slots"][: slots.shape[0]] = slots
|
||||
cuda_graph["input_lengths"].zero_()
|
||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
|
||||
input_lengths + prefix_lens_tensor
|
||||
)
|
||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
|
||||
cuda_graph["prefix_lengths"].zero_()
|
||||
cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor
|
||||
|
||||
with self._forward_context(
|
||||
block_tables=cuda_graph["block_tables"],
|
||||
cu_seqlen_prefill=None,
|
||||
input_lengths=batch.input_lengths,
|
||||
input_lengths_tensor=cuda_graph["input_lengths"],
|
||||
prefix_lens=batch.prefix_lens,
|
||||
prefix_lens_tensor=prefix_lens_tensor,
|
||||
state=cuda_graph.get("state"),
|
||||
prefix_lens_tensor=cuda_graph["prefix_lengths"],
|
||||
state=cuda_graph["state"],
|
||||
):
|
||||
# Replay the graph
|
||||
cuda_graph["graph"].replay()
|
||||
|
@ -1769,7 +1776,7 @@ class FlashCausalLM(Model):
|
|||
left = 0
|
||||
|
||||
if n_accepted_ids > 1:
|
||||
log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")
|
||||
log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
|
||||
|
||||
current_stopped = False
|
||||
for j in range(index, index + n_accepted_ids):
|
||||
|
@ -1924,9 +1931,7 @@ class FlashCausalLM(Model):
|
|||
*,
|
||||
block_tables: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
input_lengths: List[int],
|
||||
input_lengths_tensor: torch.Tensor,
|
||||
prefix_lens: List[int],
|
||||
prefix_lens_tensor: torch.Tensor,
|
||||
state: Optional[Any] = None,
|
||||
) -> ContextManager:
|
||||
|
@ -1952,7 +1957,7 @@ class FlashCausalLM(Model):
|
|||
# ),
|
||||
block_tables=block_tables,
|
||||
cu_seqlens=cu_seqlen_prefill,
|
||||
input_lengths=input_lengths_tensor,
|
||||
input_lengths=input_lengths_tensor + prefix_lens_tensor,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
|
@ -1962,7 +1967,7 @@ class FlashCausalLM(Model):
|
|||
assert input_lengths_tensor is not None
|
||||
return use_decode_state(
|
||||
state=state if state is not None else self.decode_state,
|
||||
input_lengths=input_lengths_tensor,
|
||||
input_lengths=input_lengths_tensor + prefix_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
|
|
|
@ -367,9 +367,7 @@ class VlmCausalLM(FlashCausalLM):
|
|||
with self._forward_context(
|
||||
block_tables=block_tables,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
input_lengths=batch.input_lengths,
|
||||
input_lengths_tensor=input_lengths,
|
||||
prefix_lens=batch.prefix_lens,
|
||||
prefix_lens_tensor=prefix_lens_tensor,
|
||||
):
|
||||
max_k = (input_lengths + prefix_lens_tensor).max().item()
|
||||
|
|
|
@ -77,12 +77,12 @@ def load_and_merge_adapters(
|
|||
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
|
||||
|
||||
if len(adapter_parameters.adapter_info) == 1:
|
||||
adapter_info = next(iter(adapter_parameters.adapter_info))
|
||||
adapter = next(iter(adapter_parameters.adapter_info))
|
||||
return load_module_map(
|
||||
model_id,
|
||||
adapter_info.revision,
|
||||
adapter_info.id,
|
||||
adapter_info.path,
|
||||
adapter.revision,
|
||||
adapter.id,
|
||||
adapter.path,
|
||||
weight_names,
|
||||
trust_remote_code,
|
||||
)
|
||||
|
@ -90,7 +90,6 @@ def load_and_merge_adapters(
|
|||
adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
|
||||
return _load_and_merge(
|
||||
model_id,
|
||||
adapter_params.revision,
|
||||
adapter_params,
|
||||
weight_names,
|
||||
trust_remote_code,
|
||||
|
@ -109,7 +108,6 @@ class AdapterParametersContainer:
|
|||
@lru_cache(maxsize=32)
|
||||
def _load_and_merge(
|
||||
model_id: str,
|
||||
revision: str,
|
||||
adapter_params: AdapterParametersContainer,
|
||||
weight_names: Tuple[str],
|
||||
trust_remote_code: bool = False,
|
||||
|
@ -126,6 +124,7 @@ def _load_and_merge(
|
|||
module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
|
||||
load_module_map(
|
||||
model_id,
|
||||
adapter.revision,
|
||||
adapter.id,
|
||||
adapter.path,
|
||||
weight_names,
|
||||
|
|
Loading…
Reference in New Issue