feat(router): refactor API and add openAPI schemas (#53)

This commit is contained in:
OlivierDehaene 2023-02-03 12:43:37 +01:00 committed by GitHub
parent b1482d9048
commit 20c3c5940c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 1355 additions and 496 deletions

View File

@ -5,6 +5,8 @@ on:
push: push:
branches: branches:
- 'main' - 'main'
tags:
- 'v*'
pull_request: pull_request:
branches: branches:
- 'main' - 'main'
@ -43,6 +45,8 @@ jobs:
ghcr.io/huggingface/text-generation-inference ghcr.io/huggingface/text-generation-inference
registry.internal.huggingface.tech/api-inference/community/text-generation-inference registry.internal.huggingface.tech/api-inference/community/text-generation-inference
tags: | tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }} type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }} type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}
- name: Build and push Docker image - name: Build and push Docker image

189
Cargo.lock generated
View File

@ -83,9 +83,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]] [[package]]
name = "axum" name = "axum"
version = "0.5.17" version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43" checksum = "e5694b64066a2459918d8074c2ce0d5a88f409431994c2356617c8ae0c4721fc"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"axum-core", "axum-core",
@ -101,8 +101,10 @@ dependencies = [
"mime", "mime",
"percent-encoding", "percent-encoding",
"pin-project-lite", "pin-project-lite",
"rustversion",
"serde", "serde",
"serde_json", "serde_json",
"serde_path_to_error",
"serde_urlencoded", "serde_urlencoded",
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
@ -114,9 +116,9 @@ dependencies = [
[[package]] [[package]]
name = "axum-core" name = "axum-core"
version = "0.2.9" version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37e5939e02c56fecd5c017c37df4238c0a839fa76b7f97acdd7efb804fd181cc" checksum = "1cae3e661676ffbacb30f1a824089a8c9150e71017f7e1e38f2aa32009188d34"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"bytes", "bytes",
@ -124,6 +126,7 @@ dependencies = [
"http", "http",
"http-body", "http-body",
"mime", "mime",
"rustversion",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
] ]
@ -207,7 +210,7 @@ dependencies = [
"tar", "tar",
"tempfile", "tempfile",
"thiserror", "thiserror",
"zip", "zip 0.5.13",
"zip-extensions", "zip-extensions",
] ]
@ -465,6 +468,15 @@ dependencies = [
"dirs-sys", "dirs-sys",
] ]
[[package]]
name = "dirs"
version = "4.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059"
dependencies = [
"dirs-sys",
]
[[package]] [[package]]
name = "dirs-sys" name = "dirs-sys"
version = "0.3.7" version = "0.3.7"
@ -867,6 +879,7 @@ checksum = "10a35a97730320ffe8e2d410b5d3b69279b98d2c14bdb8b70ea89ecf7888d41e"
dependencies = [ dependencies = [
"autocfg", "autocfg",
"hashbrown", "hashbrown",
"serde",
] ]
[[package]] [[package]]
@ -999,9 +1012,9 @@ checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"
[[package]] [[package]]
name = "matchit" name = "matchit"
version = "0.5.0" version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb" checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40"
[[package]] [[package]]
name = "memchr" name = "memchr"
@ -1024,6 +1037,16 @@ version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d" checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d"
[[package]]
name = "mime_guess"
version = "2.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef"
dependencies = [
"mime",
"unicase",
]
[[package]] [[package]]
name = "minimal-lexical" name = "minimal-lexical"
version = "0.2.1" version = "0.2.1"
@ -1552,12 +1575,62 @@ dependencies = [
"winreg", "winreg",
] ]
[[package]]
name = "rust-embed"
version = "6.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "283ffe2f866869428c92e0d61c2f35dfb4355293cdfdc48f49e895c15f1333d1"
dependencies = [
"rust-embed-impl",
"rust-embed-utils",
"walkdir",
]
[[package]]
name = "rust-embed-impl"
version = "6.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31ab23d42d71fb9be1b643fe6765d292c5e14d46912d13f3ae2815ca048ea04d"
dependencies = [
"proc-macro2",
"quote",
"rust-embed-utils",
"shellexpand",
"syn",
"walkdir",
]
[[package]]
name = "rust-embed-utils"
version = "7.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1669d81dfabd1b5f8e2856b8bbe146c6192b0ba22162edc738ac0a5de18f054"
dependencies = [
"sha2",
"walkdir",
]
[[package]]
name = "rustversion"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5583e89e108996506031660fe09baa5011b9dd0341b89029313006d1fb508d70"
[[package]] [[package]]
name = "ryu" name = "ryu"
version = "1.0.11" version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09" checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09"
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]] [[package]]
name = "schannel" name = "schannel"
version = "0.1.20" version = "0.1.20"
@ -1628,6 +1701,15 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "serde_path_to_error"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26b04f22b563c91331a10074bda3dd5492e3cc39d56bd557e91c0af42b6c7341"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "serde_urlencoded" name = "serde_urlencoded"
version = "0.7.1" version = "0.7.1"
@ -1660,6 +1742,15 @@ dependencies = [
"lazy_static", "lazy_static",
] ]
[[package]]
name = "shellexpand"
version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ccc8076840c4da029af4f87e4e8daeb0fca6b87bbb02e10cb60b791450e11e4"
dependencies = [
"dirs 4.0.0",
]
[[package]] [[package]]
name = "signal-hook-registry" name = "signal-hook-registry"
version = "1.4.0" version = "1.4.0"
@ -1797,7 +1888,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-client" name = "text-generation-client"
version = "0.1.0" version = "0.2.0"
dependencies = [ dependencies = [
"futures", "futures",
"prost", "prost",
@ -1812,7 +1903,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-launcher" name = "text-generation-launcher"
version = "0.1.0" version = "0.2.0"
dependencies = [ dependencies = [
"clap 4.0.22", "clap 4.0.22",
"ctrlc", "ctrlc",
@ -1827,7 +1918,7 @@ dependencies = [
[[package]] [[package]]
name = "text-generation-router" name = "text-generation-router"
version = "0.1.0" version = "0.2.0"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"axum", "axum",
@ -1845,6 +1936,8 @@ dependencies = [
"tokio-stream", "tokio-stream",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"utoipa",
"utoipa-swagger-ui",
] ]
[[package]] [[package]]
@ -1921,7 +2014,7 @@ dependencies = [
"cached-path", "cached-path",
"clap 2.34.0", "clap 2.34.0",
"derive_builder", "derive_builder",
"dirs", "dirs 3.0.2",
"esaxx-rs", "esaxx-rs",
"getrandom", "getrandom",
"indicatif 0.15.0", "indicatif 0.15.0",
@ -2234,6 +2327,15 @@ version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
[[package]]
name = "unicase"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6"
dependencies = [
"version_check",
]
[[package]] [[package]]
name = "unicode-bidi" name = "unicode-bidi"
version = "0.3.8" version = "0.3.8"
@ -2293,6 +2395,46 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "utoipa"
version = "3.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3920fa753064b1be7842bea26175ffa0dfc4a8f30bcb52b8ff03fddf8889914c"
dependencies = [
"indexmap",
"serde",
"serde_json",
"utoipa-gen",
]
[[package]]
name = "utoipa-gen"
version = "3.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "720298fac6efca20df9e457e67a1eab41a20d1c3101380b5c4dca1ca60ae0062"
dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "utoipa-swagger-ui"
version = "3.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae3d4f4da6408f0f20ff58196ed619c94306ab32635aeca3d3fa0768c0bd0de2"
dependencies = [
"axum",
"mime_guess",
"regex",
"rust-embed",
"serde",
"serde_json",
"utoipa",
"zip 0.6.4",
]
[[package]] [[package]]
name = "valuable" name = "valuable"
version = "0.1.0" version = "0.1.0"
@ -2317,6 +2459,17 @@ version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "walkdir"
version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56"
dependencies = [
"same-file",
"winapi",
"winapi-util",
]
[[package]] [[package]]
name = "want" name = "want"
version = "0.3.0" version = "0.3.0"
@ -2589,11 +2742,23 @@ dependencies = [
"time", "time",
] ]
[[package]]
name = "zip"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0445d0fbc924bb93539b4316c11afb121ea39296f99a3c4c9edad09e3658cdef"
dependencies = [
"byteorder",
"crc32fast",
"crossbeam-utils",
"flate2",
]
[[package]] [[package]]
name = "zip-extensions" name = "zip-extensions"
version = "0.6.1" version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a64c3c977bc3434ce2d4bcea8ad3c644672de0f2c402b72b9171ca80a8885d14" checksum = "a64c3c977bc3434ce2d4bcea8ad3c644672de0f2c402b72b9171ca80a8885d14"
dependencies = [ dependencies = [
"zip", "zip 0.5.13",
] ]

View File

@ -4,9 +4,6 @@ members = [
"router/client", "router/client",
"launcher" "launcher"
] ]
exclude = [
"server/safetensors",
]
[profile.release] [profile.release]
debug = 1 debug = 1

View File

@ -26,21 +26,18 @@ FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
ENV LANG=C.UTF-8 \ ENV LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \ LC_ALL=C.UTF-8 \
DEBIAN_FRONTEND=noninteractive \ DEBIAN_FRONTEND=noninteractive \
MODEL_BASE_PATH=/data \ HUGGINGFACE_HUB_CACHE=/data \
MODEL_ID=bigscience/bloom-560m \ MODEL_ID=bigscience/bloom-560m \
QUANTIZE=false \ QUANTIZE=false \
NUM_GPUS=1 \ NUM_SHARD=1 \
SAFETENSORS_FAST_GPU=1 \ SAFETENSORS_FAST_GPU=1 \
PORT=80 \ PORT=80 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NCCL_ASYNC_ERROR_HANDLING=1 \ NCCL_ASYNC_ERROR_HANDLING=1 \
CUDA_HOME=/usr/local/cuda \ CUDA_HOME=/usr/local/cuda \
LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \ LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
CONDA_DEFAULT_ENV=text-generation \ CONDA_DEFAULT_ENV=text-generation \
PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin
SHELL ["/bin/bash", "-c"]
RUN apt-get update && apt-get install -y unzip curl libssl-dev && rm -rf /var/lib/apt/lists/* RUN apt-get update && apt-get install -y unzip curl libssl-dev && rm -rf /var/lib/apt/lists/*
RUN cd ~ && \ RUN cd ~ && \
@ -71,4 +68,5 @@ COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/loca
# Install launcher # Install launcher
COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher
CMD HUGGINGFACE_HUB_CACHE=$MODEL_BASE_PATH text-generation-launcher --num-shard $NUM_GPUS --model-name $MODEL_ID --json-output ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]

View File

@ -15,17 +15,23 @@ server-dev:
router-dev: router-dev:
cd router && cargo run cd router && cargo run
integration-tests: install-router install-launcher
cargo test
python-tests:
cd server && pytest tests
run-bloom-560m: run-bloom-560m:
text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2 text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2
run-bloom-560m-quantize: run-bloom-560m-quantize:
text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2 --quantize text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize
download-bloom: download-bloom:
text-generation-server download-weights bigscience/bloom text-generation-server download-weights bigscience/bloom
run-bloom: run-bloom:
text-generation-launcher --model-name bigscience/bloom --num-shard 8 text-generation-launcher --model-id bigscience/bloom --num-shard 8
run-bloom-quantize: run-bloom-quantize:
text-generation-launcher --model-name bigscience/bloom --num-shard 8 --quantize text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize

113
README.md
View File

@ -1,16 +1,43 @@
<div align="center">
# Text Generation Inference # Text Generation Inference
<div align="center"> <a href="https://github.com/huggingface/text-generation-inference">
<img alt="GitHub Repo stars" src="https://img.shields.io/github/stars/huggingface/text-generation-inference?style=social">
</a>
<a href="https://github.com/huggingface/text-generation-inference/blob/main/LICENSE">
<img alt="License" src="https://img.shields.io/github/license/huggingface/text-generation-inference">
</a>
<a href="https://huggingface.github.io/text-generation-inference">
<img alt="Swagger API documentation" src="https://img.shields.io/badge/API-Swagger-informational">
</a>
![architecture](assets/architecture.jpg) ![architecture](assets/architecture.jpg)
</div> </div>
A Rust and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co) A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
to power Bloom, BloomZ and MT0-XXL api-inference widgets. to power LLMs api-inference widgets.
## Table of contents
- [Features](#features)
- [Officially Supported Models](#officially-supported-models)
- [Get Started](#get-started)
- [Docker](#docker)
- [Local Install](#local-install)
- [OpenAPI](#api-documentation)
- [CUDA Kernels](#cuda-kernels)
- [Run BLOOM](#run-bloom)
- [Download](#download)
- [Run](#run)
- [Quantization](#quantization)
- [Develop](#develop)
- [Testing](#testing)
## Features ## Features
- Token streaming using Server Side Events (SSE)
- [Dynamic batching of incoming requests](https://github.com/huggingface/text-generation-inference/blob/main/router/src/batcher.rs#L88) for increased total throughput - [Dynamic batching of incoming requests](https://github.com/huggingface/text-generation-inference/blob/main/router/src/batcher.rs#L88) for increased total throughput
- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) - Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
- [Safetensors](https://github.com/huggingface/safetensors) weight loading - [Safetensors](https://github.com/huggingface/safetensors) weight loading
@ -36,30 +63,63 @@ or
`AutoModelForSeq2SeqLM.from_pretrained(<model>, device_map="auto")` `AutoModelForSeq2SeqLM.from_pretrained(<model>, device_map="auto")`
## Load Tests for BLOOM ## Get started
See `k6/load_test.js` ### Docker
| | avg | min | med | max | p(90) | p(95) | RPS | The easiest way of getting started is using the official Docker container:
|--------------------------------------------------------------|-----------|--------------|-----------|------------|-----------|-----------|----------|
| [Original code](https://github.com/huggingface/transformers_bloom_parallel) | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 |
| New batching logic | **5.44s** | **959.53ms** | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |
## Install
```shell ```shell
make install model=bigscience/bloom-560m
num_shard=2
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --num-shard $num_shard
``` ```
## Run You can then query the model using either the `/generate` or `/generate_stream` routes:
### BLOOM 560-m
```shell ```shell
curl 127.0.0.1:8080/generate \
-X POST \
-d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
-H 'Content-Type: application/json'
```
```shell
curl 127.0.0.1:8080/generate_stream \
-X POST \
-d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
-H 'Content-Type: application/json'
```
To use GPUs, you will need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).
### API documentation
You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
### Local install
You can also opt to install `text-generation-inference` locally. You will need to have cargo and Python installed on your
machine
```shell
BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels
make run-bloom-560m make run-bloom-560m
``` ```
### BLOOM ### CUDA Kernels
The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can remove
the kernels by using the `BUILD_EXTENSIONS=False` environment variable.
Be aware that the official Docker image has them enabled by default.
## Run BLOOM
### Download
First you need to download the weights: First you need to download the weights:
@ -67,29 +127,30 @@ First you need to download the weights:
make download-bloom make download-bloom
``` ```
### Run
```shell ```shell
make run-bloom # Requires 8xA100 80GB make run-bloom # Requires 8xA100 80GB
``` ```
### Quantization
You can also quantize the weights with bitsandbytes to reduce the VRAM requirement: You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:
```shell ```shell
make run-bloom-quantize # Requires 8xA100 40GB make run-bloom-quantize # Requires 8xA100 40GB
``` ```
## Test
```shell
curl 127.0.0.1:3000/generate \
-v \
-X POST \
-d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
-H 'Content-Type: application/json'
```
## Develop ## Develop
```shell ```shell
make server-dev make server-dev
make router-dev make router-dev
```
## Testing
```shell
make python-tests
make integration-tests
``` ```

View File

@ -4,9 +4,9 @@ endpoint_name: bloom-inference
model: azureml:bloom:1 model: azureml:bloom:1
model_mount_path: /var/azureml-model model_mount_path: /var/azureml-model
environment_variables: environment_variables:
MODEL_BASE_PATH: /var/azureml-model/bloom HUGGINGFACE_HUB_CACHE: /var/azureml-model/bloom
MODEL_ID: bigscience/bloom MODEL_ID: bigscience/bloom
NUM_GPUS: 8 NUM_SHARD: 8
environment: environment:
image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.3.1 image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.3.1
inference_config: inference_config:

Binary file not shown.

Before

Width:  |  Height:  |  Size: 132 KiB

After

Width:  |  Height:  |  Size: 334 KiB

30
docs/index.html Normal file
View File

@ -0,0 +1,30 @@
<html>
<head>
<!-- Load the latest Swagger UI code and style from npm using unpkg.com -->
<script src="https://unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
<link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3/swagger-ui.css"/>
<title>Text Generation Inference API</title>
</head>
<body>
<div id="swagger-ui"></div> <!-- Div to hold the UI component -->
<script>
window.onload = function () {
// Begin Swagger UI call region
const ui = SwaggerUIBundle({
url: "openapi.json", //Location of Open API spec in the repo
dom_id: '#swagger-ui',
deepLinking: true,
supportedSubmitMethods: [],
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
plugins: [
SwaggerUIBundle.plugins.DownloadUrl
],
})
window.ui = ui
}
</script>
</body>
</html>

446
docs/openapi.json Normal file
View File

@ -0,0 +1,446 @@
{
"openapi": "3.0.3",
"info": {
"title": "Text Generation Inference",
"description": "Text Generation Webserver",
"contact": {
"name": "Olivier Dehaene",
"email": ""
},
"license": {
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "0.2.0"
},
"paths": {
"/generate": {
"post": {
"tags": [
"Text Generation Inference"
],
"summary": "Generate tokens",
"description": "Generate tokens",
"operationId": "generate",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/GenerateRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Generated Text",
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/GenerateResponse"
}
}
}
}
},
"422": {
"description": "Input validation error",
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
},
"example": {
"error": "Input validation error"
}
}
}
},
"424": {
"description": "Generation Error",
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
},
"example": {
"error": "Request failed during generation"
}
}
}
},
"429": {
"description": "Model is overloaded",
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
},
"example": {
"error": "Model is overloaded"
}
}
}
},
"500": {
"description": "Incomplete generation",
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
},
"example": {
"error": "Incomplete generation"
}
}
}
}
},
"deprecated": false
}
},
"/generate_stream": {
"post": {
"tags": [
"Text Generation Inference"
],
"summary": "Generate a stream of token using Server Side Events",
"description": "Generate a stream of token using Server Side Events",
"operationId": "generate_stream",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/GenerateRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Generated Text",
"content": {
"text/event-stream ": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/StreamResponse"
}
}
}
}
},
"422": {
"description": "Input validation error",
"content": {
"text/event-stream ": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
},
"example": {
"error": "Input validation error"
}
}
}
},
"424": {
"description": "Generation Error",
"content": {
"text/event-stream ": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
},
"example": {
"error": "Request failed during generation"
}
}
}
},
"429": {
"description": "Model is overloaded",
"content": {
"text/event-stream ": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
},
"example": {
"error": "Model is overloaded"
}
}
}
},
"500": {
"description": "Incomplete generation",
"content": {
"text/event-stream ": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ErrorResponse"
}
},
"example": {
"error": "Incomplete generation"
}
}
}
}
},
"deprecated": false
}
}
},
"components": {
"schemas": {
"Details": {
"type": "object",
"required": [
"finish_reason",
"generated_tokens"
],
"properties": {
"finish_reason": {
"$ref": "#/components/schemas/FinishReason"
},
"generated_tokens": {
"type": "integer",
"format": "int32",
"example": 1
},
"prefill": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Token"
}
},
"seed": {
"type": "integer",
"format": "int64",
"example": 42
},
"tokens": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Token"
}
}
}
},
"ErrorResponse": {
"type": "object",
"required": [
"error"
],
"properties": {
"error": {
"type": "string"
}
}
},
"FinishReason": {
"type": "string",
"enum": [
"length",
"eos_token",
"stop_sequence"
]
},
"GenerateParameters": {
"type": "object",
"properties": {
"details": {
"type": "boolean",
"default": "true"
},
"do_sample": {
"type": "boolean",
"default": "false",
"example": true
},
"max_new_tokens": {
"type": "integer",
"format": "int32",
"default": "20",
"exclusiveMaximum": 512.0,
"exclusiveMinimum": 0.0
},
"repetition_penalty": {
"type": "number",
"format": "float",
"default": "null",
"example": 1.03,
"nullable": true,
"exclusiveMinimum": 0.0
},
"seed": {
"type": "integer",
"format": "int64"
},
"stop": {
"type": "array",
"items": {
"type": "string"
},
"example": [
"photographer"
],
"maxItems": 4
},
"temperature": {
"type": "number",
"format": "float",
"default": "null",
"example": 0.5,
"nullable": true,
"exclusiveMinimum": 0.0
},
"top_k": {
"type": "integer",
"format": "int32",
"default": "null",
"example": 10,
"nullable": true,
"exclusiveMinimum": 0.0
},
"top_p": {
"type": "number",
"format": "float",
"default": "null",
"example": 0.95,
"nullable": true,
"maximum": 1.0,
"exclusiveMinimum": 0.0
}
}
},
"GenerateRequest": {
"type": "object",
"required": [
"inputs"
],
"properties": {
"inputs": {
"type": "string",
"example": "My name is Olivier and I"
},
"parameters": {
"$ref": "#/components/schemas/GenerateParameters"
}
}
},
"GenerateResponse": {
"type": "object",
"required": [
"generated_text"
],
"properties": {
"details": {
"$ref": "#/components/schemas/Details"
},
"generated_text": {
"type": "string",
"example": "test"
}
}
},
"StreamDetails": {
"type": "object",
"required": [
"finish_reason",
"generated_tokens"
],
"properties": {
"finish_reason": {
"$ref": "#/components/schemas/FinishReason"
},
"generated_tokens": {
"type": "integer",
"format": "int32",
"example": 1
},
"seed": {
"type": "integer",
"format": "int64",
"example": 42
}
}
},
"StreamResponse": {
"type": "object",
"required": [
"token"
],
"properties": {
"details": {
"$ref": "#/components/schemas/StreamDetails"
},
"generated_text": {
"type": "string",
"default": "null",
"example": "test",
"nullable": true
},
"token": {
"$ref": "#/components/schemas/Token"
}
}
},
"Token": {
"type": "object",
"required": [
"id",
"text",
"logprob"
],
"properties": {
"id": {
"type": "integer",
"format": "int32",
"example": 0
},
"logprob": {
"type": "number",
"format": "float",
"example": -0.34,
"nullable": true
},
"text": {
"type": "string",
"example": "test"
}
}
}
}
},
"tags": [
{
"name": "Text Generation Inference",
"description": "Hugging Face Text Generation Inference API"
}
]
}

View File

@ -1,6 +1,6 @@
[package] [package]
name = "text-generation-launcher" name = "text-generation-launcher"
version = "0.1.0" version = "0.2.0"
edition = "2021" edition = "2021"
authors = ["Olivier Dehaene"] authors = ["Olivier Dehaene"]
description = "Text Generation Launcher" description = "Text Generation Launcher"

View File

@ -19,7 +19,7 @@ use subprocess::{Popen, PopenConfig, PopenError, Redirection};
#[clap(author, version, about, long_about = None)] #[clap(author, version, about, long_about = None)]
struct Args { struct Args {
#[clap(default_value = "bigscience/bloom-560m", long, env)] #[clap(default_value = "bigscience/bloom-560m", long, env)]
model_name: String, model_id: String,
#[clap(long, env)] #[clap(long, env)]
revision: Option<String>, revision: Option<String>,
#[clap(long, env)] #[clap(long, env)]
@ -49,7 +49,7 @@ struct Args {
fn main() -> ExitCode { fn main() -> ExitCode {
// Pattern match configuration // Pattern match configuration
let Args { let Args {
model_name, model_id,
revision, revision,
num_shard, num_shard,
quantize, quantize,
@ -92,7 +92,7 @@ fn main() -> ExitCode {
// Start shard processes // Start shard processes
for rank in 0..num_shard { for rank in 0..num_shard {
let model_name = model_name.clone(); let model_id = model_id.clone();
let revision = revision.clone(); let revision = revision.clone();
let uds_path = shard_uds_path.clone(); let uds_path = shard_uds_path.clone();
let master_addr = master_addr.clone(); let master_addr = master_addr.clone();
@ -101,7 +101,7 @@ fn main() -> ExitCode {
let shutdown_sender = shutdown_sender.clone(); let shutdown_sender = shutdown_sender.clone();
thread::spawn(move || { thread::spawn(move || {
shard_manager( shard_manager(
model_name, model_id,
revision, revision,
quantize, quantize,
uds_path, uds_path,
@ -167,7 +167,7 @@ fn main() -> ExitCode {
"--master-shard-uds-path".to_string(), "--master-shard-uds-path".to_string(),
format!("{}-0", shard_uds_path), format!("{}-0", shard_uds_path),
"--tokenizer-name".to_string(), "--tokenizer-name".to_string(),
model_name, model_id,
]; ];
if json_output { if json_output {
@ -256,7 +256,7 @@ enum ShardStatus {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn shard_manager( fn shard_manager(
model_name: String, model_id: String,
revision: Option<String>, revision: Option<String>,
quantize: bool, quantize: bool,
uds_path: String, uds_path: String,
@ -278,7 +278,7 @@ fn shard_manager(
let mut shard_argv = vec![ let mut shard_argv = vec![
"text-generation-server".to_string(), "text-generation-server".to_string(),
"serve".to_string(), "serve".to_string(),
model_name, model_id,
"--uds-path".to_string(), "--uds-path".to_string(),
uds_path, uds_path,
"--logger-level".to_string(), "--logger-level".to_string(),

View File

@ -1,123 +1,122 @@
[ {
{ "details": {
"details": { "finish_reason": "length",
"finish_reason": "length", "generated_tokens": 20,
"generated_tokens": 20, "prefill": [
"prefill": [ {
[ "id": 10264,
10264, "logprob": null,
"Test", "text": "Test"
null },
], {
[ "id": 8821,
8821, "logprob": -11.894989,
" request", "text": " request"
-11.895094 }
] ],
], "seed": null,
"tokens": [ "tokens": [
[ {
17, "id": 17,
".", "logprob": -1.8267672,
-1.8267941 "text": "."
], },
[ {
1587, "id": 1587,
"get", "logprob": -2.4674969,
-2.4674964 "text": "get"
], },
[ {
11, "id": 11,
"(", "logprob": -1.906001,
-1.9060438 "text": "("
], },
[ {
5, "id": 5,
"\"", "logprob": -1.2279545,
-1.2279553 "text": "\""
], },
[ {
4899, "id": 4899,
"action", "logprob": -4.170299,
-4.170306 "text": "action"
], },
[ {
5, "id": 5,
"\"", "logprob": -0.32478866,
-0.3247902 "text": "\""
], },
[ {
12, "id": 12,
")", "logprob": -1.0773665,
-1.0773602 "text": ")"
], },
[ {
30, "id": 30,
";", "logprob": -0.27640742,
-0.27640444 "text": ";"
], },
[ {
837, "id": 837,
"\n ", "logprob": -1.6970354,
-1.6970599 "text": "\n "
], },
[ {
1320, "id": 1320,
" if", "logprob": -1.4495516,
-1.4495552 "text": " if"
], },
[ {
375, "id": 375,
" (", "logprob": -0.23609057,
-0.2360998 "text": " ("
], },
[ {
4899, "id": 4899,
"action", "logprob": -1.1916996,
-1.1916926 "text": "action"
], },
[ {
3535, "id": 3535,
" ==", "logprob": -0.8918753,
-0.8918663 "text": " =="
], },
[ {
5109, "id": 5109,
" null", "logprob": -0.3933342,
-0.39334255 "text": " null"
], },
[ {
12, "id": 12,
")", "logprob": -0.43212673,
-0.4321134 "text": ")"
], },
[ {
731, "id": 731,
" {", "logprob": -0.17702064,
-0.17701954 "text": " {"
], },
[ {
1260, "id": 1260,
"\n ", "logprob": -0.07027565,
-0.07027287 "text": "\n "
], },
[ {
10519, "id": 10519,
" throw", "logprob": -1.3915029,
-1.3915133 "text": " throw"
], },
[ {
2084, "id": 2084,
" new", "logprob": -0.04201372,
-0.042013377 "text": " new"
], },
[ {
150858, "id": 150858,
" RuntimeException", "logprob": -1.7329919,
-1.7330077 "text": " RuntimeException"
] }
] ]
}, },
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException" "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
} }
]

View File

@ -9,11 +9,18 @@ use std::thread::sleep;
use std::time::Duration; use std::time::Duration;
use subprocess::{Popen, PopenConfig, Redirection}; use subprocess::{Popen, PopenConfig, Redirection};
#[derive(Deserialize)]
pub struct Token {
id: u32,
text: String,
logprob: Option<f32>,
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct Details { struct Details {
finish_reason: String, finish_reason: String,
generated_tokens: u32, generated_tokens: u32,
tokens: Vec<(u32, String, Option<f32>)>, tokens: Vec<Token>,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -22,11 +29,11 @@ struct GeneratedText {
details: Details, details: Details,
} }
fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen { fn start_launcher(model_id: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
let argv = vec![ let argv = vec![
"text-generation-launcher".to_string(), "text-generation-launcher".to_string(),
"--model-name".to_string(), "--model-id".to_string(),
model_name.clone(), model_id.clone(),
"--num-shard".to_string(), "--num-shard".to_string(),
num_shard.to_string(), num_shard.to_string(),
"--port".to_string(), "--port".to_string(),
@ -68,16 +75,16 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
launcher.terminate().unwrap(); launcher.terminate().unwrap();
launcher.wait().unwrap(); launcher.wait().unwrap();
panic!("failed to launch {}", model_name) panic!("failed to launch {}", model_id)
} }
fn test_model( fn test_model(
model_name: String, model_id: String,
num_shard: usize, num_shard: usize,
port: usize, port: usize,
master_port: usize, master_port: usize,
) -> GeneratedText { ) -> GeneratedText {
let mut launcher = start_launcher(model_name, num_shard, port, master_port); let mut launcher = start_launcher(model_id, num_shard, port, master_port);
let data = r#" let data = r#"
{ {
@ -109,8 +116,8 @@ fn read_json(name: &str) -> GeneratedText {
let file = File::open(d).unwrap(); let file = File::open(d).unwrap();
let reader = BufReader::new(file); let reader = BufReader::new(file);
let mut results: Vec<GeneratedText> = serde_json::from_reader(reader).unwrap(); let result: GeneratedText = serde_json::from_reader(reader).unwrap();
results.pop().unwrap() result
} }
fn compare_results(result: GeneratedText, expected: GeneratedText) { fn compare_results(result: GeneratedText, expected: GeneratedText) {
@ -127,13 +134,13 @@ fn compare_results(result: GeneratedText, expected: GeneratedText) {
.into_iter() .into_iter()
.zip(expected.details.tokens.into_iter()) .zip(expected.details.tokens.into_iter())
{ {
assert_eq!(token.0, expected_token.0); assert_eq!(token.id, expected_token.id);
assert_eq!(token.1, expected_token.1); assert_eq!(token.text, expected_token.text);
if let Some(logprob) = token.2 { if let Some(logprob) = token.logprob {
let expected_logprob = expected_token.2.unwrap(); let expected_logprob = expected_token.logprob.unwrap();
assert_float_eq!(logprob, expected_logprob, abs <= 0.001); assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
} else { } else {
assert_eq!(token.2, expected_token.2); assert_eq!(token.logprob, expected_token.logprob);
} }
} }
} }

View File

@ -1,118 +1,117 @@
[ {
{ "details": {
"details": { "finish_reason": "length",
"finish_reason": "length", "generated_tokens": 20,
"generated_tokens": 20, "prefill": [
"prefill": [ {
[ "id": 0,
0, "logprob": null,
"<pad>", "text": "<pad>"
null }
] ],
], "seed": null,
"tokens": [ "tokens": [
[ {
259, "id": 259,
"", "logprob": -1.3656927,
-1.3656927 "text": ""
], },
[ {
215100, "id": 215100,
"\"\"\"", "logprob": -2.6551573,
-2.6551573 "text": "\"\"\""
], },
[ {
46138, "id": 46138,
"Test", "logprob": -1.8059857,
-1.8059857 "text": "Test"
], },
[ {
287, "id": 287,
"the", "logprob": -1.2102449,
-1.2102449 "text": "the"
], },
[ {
259, "id": 259,
"", "logprob": -1.6057279,
-1.6057279 "text": ""
], },
[ {
49076, "id": 49076,
"contents", "logprob": -3.6060903,
-3.6060903 "text": "contents"
], },
[ {
304, "id": 304,
"of", "logprob": -0.5270343,
-0.5270343 "text": "of"
], },
[ {
287, "id": 287,
"the", "logprob": -0.62522805,
-0.62522805 "text": "the"
], },
[ {
259, "id": 259,
"", "logprob": -1.4069618,
-1.4069618 "text": ""
], },
[ {
49076, "id": 49076,
"contents", "logprob": -2.621994,
-2.621994 "text": "contents"
], },
[ {
304, "id": 304,
"of", "logprob": -1.3172221,
-1.3172221 "text": "of"
], },
[ {
287, "id": 287,
"the", "logprob": -0.3501925,
-0.3501925 "text": "the"
], },
[ {
259, "id": 259,
"", "logprob": -0.7219573,
-0.7219573 "text": ""
], },
[ {
49076, "id": 49076,
"contents", "logprob": -1.0494149,
-1.0494149 "text": "contents"
], },
[ {
260, "id": 260,
".", "logprob": -1.0803378,
-1.0803378 "text": "."
], },
[ {
259, "id": 259,
"", "logprob": -0.32933083,
-0.32933083 "text": ""
], },
[ {
215100, "id": 215100,
"\"\"\"", "logprob": -0.11268901,
-0.11268901 "text": "\"\"\""
], },
[ {
2978, "id": 2978,
"test", "logprob": -1.5846587,
-1.5846587 "text": "test"
], },
[ {
290, "id": 290,
"_", "logprob": -0.49796978,
-0.49796978 "text": "_"
], },
[ {
4125, "id": 4125,
"test", "logprob": -2.0026445,
-2.0026445 "text": "test"
] }
] ]
}, },
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test" "generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
} }
]

View File

@ -71,13 +71,19 @@ message Batch {
uint32 size = 3; uint32 size = 3;
} }
enum FinishReason {
FINISH_REASON_LENGTH = 0;
FINISH_REASON_EOS_TOKEN = 1;
FINISH_REASON_STOP_SEQUENCE = 2;
}
message GeneratedText { message GeneratedText {
/// Output /// Output
string text = 1; string text = 1;
/// Number of generated tokens /// Number of generated tokens
uint32 generated_tokens = 2; uint32 generated_tokens = 2;
/// Finish reason /// Finish reason
string finish_reason = 3; FinishReason finish_reason = 3;
/// Seed /// Seed
optional uint64 seed = 4; optional uint64 seed = 4;
} }

View File

@ -1,6 +1,6 @@
[package] [package]
name = "text-generation-router" name = "text-generation-router"
version = "0.1.0" version = "0.2.0"
edition = "2021" edition = "2021"
authors = ["Olivier Dehaene"] authors = ["Olivier Dehaene"]
description = "Text Generation Webserver" description = "Text Generation Webserver"
@ -14,7 +14,7 @@ path = "src/main.rs"
[dependencies] [dependencies]
async-stream = "0.3.3" async-stream = "0.3.3"
axum = { version = "0.5.16", features = ["json", "serde_json"] } axum = { version = "0.6.4", features = ["json"] }
text-generation-client = { path = "client" } text-generation-client = { path = "client" }
clap = { version = "4.0.15", features = ["derive", "env"] } clap = { version = "4.0.15", features = ["derive", "env"] }
futures = "0.3.24" futures = "0.3.24"
@ -29,4 +29,6 @@ tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot"
tokio-stream = "0.1.11" tokio-stream = "0.1.11"
tracing = "0.1.36" tracing = "0.1.36"
tracing-subscriber = { version = "0.3.15", features = ["json"] } tracing-subscriber = { version = "0.3.15", features = ["json"] }
utoipa = { version = "3.0.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }

View File

@ -1,6 +1,6 @@
[package] [package]
name = "text-generation-client" name = "text-generation-client"
version = "0.1.0" version = "0.2.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]

View File

@ -7,8 +7,8 @@ mod sharded_client;
pub use client::Client; pub use client::Client;
pub use pb::generate::v1::{ pub use pb::generate::v1::{
Batch, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, Request, Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
StoppingCriteriaParameters, Request, StoppingCriteriaParameters,
}; };
pub use sharded_client::ShardedClient; pub use sharded_client::ShardedClient;
use thiserror::Error; use thiserror::Error;

View File

@ -127,7 +127,7 @@ impl Infer {
.into_iter() .into_iter()
.zip(tokens.logprobs.into_iter()) .zip(tokens.logprobs.into_iter())
.zip(tokens.texts.into_iter()) .zip(tokens.texts.into_iter())
.map(|((id, logprob), text)| Token(id, text, logprob)) .map(|((id, logprob), text)| Token { id, text, logprob })
.collect(); .collect();
} }
// Push last token // Push last token
@ -282,11 +282,11 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
} }
// Create last Token // Create last Token
let token = Token( let token = Token {
generation.token_id, id: generation.token_id,
generation.token_text, text: generation.token_text,
generation.token_logprob, logprob: generation.token_logprob,
); };
if let Some(generated_text) = generation.generated_text { if let Some(generated_text) = generation.generated_text {
// Remove entry as this is the last message // Remove entry as this is the last message

View File

@ -1,5 +1,4 @@
/// Text Generation Inference Webserver /// Text Generation Inference Webserver
mod infer; mod infer;
mod queue; mod queue;
pub mod server; pub mod server;
@ -8,45 +7,55 @@ mod validation;
use infer::Infer; use infer::Infer;
use queue::{Entry, Queue}; use queue::{Entry, Queue};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use validation::Validation; use validation::Validation;
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct GenerateParameters { pub(crate) struct GenerateParameters {
#[serde(default = "default_temperature")] #[serde(default)]
pub temperature: f32, #[schema(
#[serde(default = "default_repetition_penalty")] exclusive_minimum = 0.0,
pub repetition_penalty: f32, nullable = true,
#[serde(default = "default_top_k")] default = "null",
pub top_k: i32, example = 0.5
#[serde(default = "default_top_p")] )]
pub top_p: f32, pub temperature: Option<f32>,
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
nullable = true,
default = "null",
example = 1.03
)]
pub repetition_penalty: Option<f32>,
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
pub top_k: Option<i32>,
#[serde(default)]
#[schema(
exclusive_minimum = 0.0,
maximum = 1.0,
nullable = true,
default = "null",
example = 0.95
)]
pub top_p: Option<f32>,
#[serde(default = "default_do_sample")] #[serde(default = "default_do_sample")]
#[schema(default = "false", example = true)]
pub do_sample: bool, pub do_sample: bool,
#[serde(default = "default_max_new_tokens")] #[serde(default = "default_max_new_tokens")]
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
pub max_new_tokens: u32, pub max_new_tokens: u32,
#[serde(default)] #[serde(default)]
#[schema(inline, max_items = 4, example = json!(["photographer"]))]
pub stop: Vec<String>, pub stop: Vec<String>,
#[serde(default)] #[serde(default)]
#[schema(default = "true")]
pub details: bool, pub details: bool,
#[serde(default)] #[serde(default)]
pub seed: Option<u64>, pub seed: Option<u64>,
} }
fn default_temperature() -> f32 {
1.0
}
fn default_repetition_penalty() -> f32 {
1.0
}
fn default_top_k() -> i32 {
0
}
fn default_top_p() -> f32 {
1.0
}
fn default_do_sample() -> bool { fn default_do_sample() -> bool {
false false
} }
@ -57,10 +66,10 @@ fn default_max_new_tokens() -> u32 {
fn default_parameters() -> GenerateParameters { fn default_parameters() -> GenerateParameters {
GenerateParameters { GenerateParameters {
temperature: default_temperature(), temperature: None,
repetition_penalty: default_repetition_penalty(), repetition_penalty: None,
top_k: default_top_k(), top_k: None,
top_p: default_top_p(), top_p: None,
do_sample: default_do_sample(), do_sample: default_do_sample(),
max_new_tokens: default_max_new_tokens(), max_new_tokens: default_max_new_tokens(),
stop: vec![], stop: vec![],
@ -69,42 +78,77 @@ fn default_parameters() -> GenerateParameters {
} }
} }
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct GenerateRequest { pub(crate) struct GenerateRequest {
#[schema(example = "My name is Olivier and I")]
pub inputs: String, pub inputs: String,
#[serde(default = "default_parameters")] #[serde(default = "default_parameters")]
pub parameters: GenerateParameters, pub parameters: GenerateParameters,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, ToSchema)]
pub struct Token(u32, String, f32); pub struct Token {
#[schema(example = 0)]
id: u32,
#[schema(example = "test")]
text: String,
#[schema(nullable = true, example = -0.34)]
logprob: f32,
}
#[derive(Serialize)] #[derive(Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))]
pub(crate) enum FinishReason {
#[schema(rename = "length")]
Length,
#[serde(rename = "eos_token")]
#[schema(rename = "eos_token")]
EndOfSequenceToken,
#[schema(rename = "stop_sequence")]
StopSequence,
}
#[derive(Serialize, ToSchema)]
pub(crate) struct Details { pub(crate) struct Details {
pub finish_reason: String, #[schema(example = "length")]
pub finish_reason: FinishReason,
#[schema(example = 1)]
pub generated_tokens: u32, pub generated_tokens: u32,
#[schema(example = 42)]
pub seed: Option<u64>, pub seed: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prefill: Option<Vec<Token>>, pub prefill: Option<Vec<Token>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tokens: Option<Vec<Token>>, pub tokens: Option<Vec<Token>>,
} }
#[derive(Serialize)] #[derive(Serialize, ToSchema)]
pub(crate) struct GenerateResponse { pub(crate) struct GenerateResponse {
#[schema(example = "test")]
pub generated_text: String, pub generated_text: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<Details>, pub details: Option<Details>,
} }
#[derive(Serialize)] #[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse { pub(crate) struct StreamDetails {
pub token: Token, #[schema(example = "length")]
pub generated_text: Option<String>, pub finish_reason: FinishReason,
pub details: Option<Details>, #[schema(example = 1)]
pub generated_tokens: u32,
#[schema(example = 42)]
pub seed: Option<u64>,
} }
#[derive(Serialize)] #[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse {
pub token: Token,
#[schema(nullable = true, default = "null", example = "test")]
pub generated_text: Option<String>,
#[schema(nullable = true, default = "null")]
pub details: Option<StreamDetails>,
}
#[derive(Serialize, ToSchema)]
pub(crate) struct ErrorResponse { pub(crate) struct ErrorResponse {
#[schema(inline)]
pub error: String, pub error: String,
} }

View File

@ -1,8 +1,8 @@
/// HTTP Server logic /// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse}; use crate::infer::{InferError, InferStreamResponse};
use crate::{ use crate::{
Details, ErrorResponse, GenerateParameters, GenerateRequest, GenerateResponse, Infer, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
StreamResponse, Validation, Infer, StreamDetails, StreamResponse, Token, Validation,
}; };
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode}; use axum::http::{HeaderMap, StatusCode};
@ -19,6 +19,8 @@ use tokio::signal;
use tokio::time::Instant; use tokio::time::Instant;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tracing::instrument; use tracing::instrument;
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
/// Health check method /// Health check method
#[instrument(skip(infer))] #[instrument(skip(infer))]
@ -32,13 +34,13 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
.generate(GenerateRequest { .generate(GenerateRequest {
inputs: "liveness".to_string(), inputs: "liveness".to_string(),
parameters: GenerateParameters { parameters: GenerateParameters {
temperature: 1.0, temperature: None,
repetition_penalty: 1.0, repetition_penalty: None,
top_k: 0, top_k: None,
top_p: 1.0, top_p: None,
do_sample: false, do_sample: false,
max_new_tokens: 1, max_new_tokens: 1,
stop: vec![], stop: Vec::new(),
details: false, details: false,
seed: None, seed: None,
}, },
@ -47,7 +49,24 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
Ok(()) Ok(())
} }
/// Generate method /// Generate tokens
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/generate",
request_body = GenerateRequest,
responses(
(status = 200, description = "Generated Text", body = [GenerateResponse]),
(status = 424, description = "Generation Error", body = [ErrorResponse],
example = json!({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = [ErrorResponse],
example = json!({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = [ErrorResponse],
example = json!({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = [ErrorResponse],
example = json!({"error": "Incomplete generation"})),
)
)]
#[instrument( #[instrument(
skip(infer), skip(infer),
fields( fields(
@ -76,7 +95,7 @@ async fn generate(
// Token details // Token details
let details = match details { let details = match details {
true => Some(Details { true => Some(Details {
finish_reason: response.generated_text.finish_reason, finish_reason: FinishReason::from(response.generated_text.finish_reason),
generated_tokens: response.generated_text.generated_tokens, generated_tokens: response.generated_text.generated_tokens,
prefill: Some(response.prefill), prefill: Some(response.prefill),
tokens: Some(response.tokens), tokens: Some(response.tokens),
@ -132,7 +151,29 @@ async fn generate(
Ok((headers, Json(response))) Ok((headers, Json(response)))
} }
/// Generate stream method /// Generate a stream of token using Server Side Events
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/generate_stream",
request_body = GenerateRequest,
responses(
(status = 200, description = "Generated Text", body = [StreamResponse],
content_type="text/event-stream "),
(status = 424, description = "Generation Error", body = [ErrorResponse],
example = json!({"error": "Request failed during generation"}),
content_type="text/event-stream "),
(status = 429, description = "Model is overloaded", body = [ErrorResponse],
example = json!({"error": "Model is overloaded"}),
content_type="text/event-stream "),
(status = 422, description = "Input validation error", body = [ErrorResponse],
example = json!({"error": "Input validation error"}),
content_type="text/event-stream "),
(status = 500, description = "Incomplete generation", body = [ErrorResponse],
example = json!({"error": "Incomplete generation"}),
content_type="text/event-stream "),
)
)]
#[instrument( #[instrument(
skip(infer), skip(infer),
fields( fields(
@ -185,11 +226,9 @@ async fn generate_stream(
} => { } => {
// Token details // Token details
let details = match details { let details = match details {
true => Some(Details { true => Some(StreamDetails {
finish_reason: generated_text.finish_reason, finish_reason: FinishReason::from(generated_text.finish_reason),
generated_tokens: generated_text.generated_tokens, generated_tokens: generated_text.generated_tokens,
prefill: None,
tokens: None,
seed: generated_text.seed, seed: generated_text.seed,
}), }),
false => None, false => None,
@ -265,6 +304,39 @@ pub async fn run(
validation_workers: usize, validation_workers: usize,
addr: SocketAddr, addr: SocketAddr,
) { ) {
// OpenAPI documentation
#[derive(OpenApi)]
#[openapi(
paths(
generate,
generate_stream,
),
components(
schemas(
GenerateRequest,
GenerateParameters,
Token,
GenerateResponse,
Details,
FinishReason,
StreamResponse,
StreamDetails,
ErrorResponse,
)
),
tags(
(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
),
info(
title = "Text Generation Inference",
license(
name = "Apache 2.0",
url = "https://www.apache.org/licenses/LICENSE-2.0"
)
)
)]
struct ApiDoc;
// Create state // Create state
let validation = Validation::new(validation_workers, tokenizer, max_input_length); let validation = Validation::new(validation_workers, tokenizer, max_input_length);
let infer = Infer::new( let infer = Infer::new(
@ -277,6 +349,7 @@ pub async fn run(
// Create router // Create router
let app = Router::new() let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
.route("/", post(generate)) .route("/", post(generate))
.route("/generate", post(generate)) .route("/generate", post(generate))
.route("/generate_stream", post(generate_stream)) .route("/generate_stream", post(generate_stream))
@ -320,6 +393,17 @@ async fn shutdown_signal() {
tracing::info!("signal received, starting graceful shutdown"); tracing::info!("signal received, starting graceful shutdown");
} }
impl From<i32> for FinishReason {
fn from(finish_reason: i32) -> Self {
let finish_reason = text_generation_client::FinishReason::from_i32(finish_reason).unwrap();
match finish_reason {
text_generation_client::FinishReason::Length => FinishReason::Length,
text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
}
}
}
/// Convert to Axum supported formats /// Convert to Axum supported formats
impl From<InferError> for (StatusCode, Json<ErrorResponse>) { impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
fn from(err: InferError) -> Self { fn from(err: InferError) -> Self {

View File

@ -110,30 +110,58 @@ fn validate(
max_input_length: usize, max_input_length: usize,
rng: &mut ThreadRng, rng: &mut ThreadRng,
) -> Result<ValidGenerateRequest, ValidationError> { ) -> Result<ValidGenerateRequest, ValidationError> {
if request.parameters.temperature <= 0.0 { let GenerateParameters {
temperature,
repetition_penalty,
top_k,
top_p,
do_sample,
max_new_tokens,
stop: stop_sequences,
seed,
..
} = request.parameters;
let temperature = temperature.unwrap_or(1.0);
if temperature <= 0.0 {
return Err(ValidationError::Temperature); return Err(ValidationError::Temperature);
} }
if request.parameters.repetition_penalty <= 0.0 {
let repetition_penalty = repetition_penalty.unwrap_or(1.0);
if repetition_penalty <= 0.0 {
return Err(ValidationError::RepetitionPenalty); return Err(ValidationError::RepetitionPenalty);
} }
if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
let top_p = top_p.unwrap_or(1.0);
if top_p <= 0.0 || top_p > 1.0 {
return Err(ValidationError::TopP); return Err(ValidationError::TopP);
} }
if request.parameters.top_k < 0 {
return Err(ValidationError::TopK); // Different because the proto default value is 0 while it is not a valid value
} // for the user
if request.parameters.max_new_tokens > MAX_MAX_NEW_TOKENS { let top_k: u32 = match top_k {
None => Ok(0),
Some(top_k) => {
if top_k <= 0 {
return Err(ValidationError::TopK);
}
Ok(top_k as u32)
}
}?;
if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS)); return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
} }
if request.parameters.stop.len() > MAX_STOP_SEQUENCES {
if stop_sequences.len() > MAX_STOP_SEQUENCES {
return Err(ValidationError::StopSequence( return Err(ValidationError::StopSequence(
MAX_STOP_SEQUENCES, MAX_STOP_SEQUENCES,
request.parameters.stop.len(), stop_sequences.len(),
)); ));
} }
// If seed is None, assign a random one // If seed is None, assign a random one
let seed = match request.parameters.seed { let seed = match seed {
None => rng.gen(), None => rng.gen(),
Some(seed) => seed, Some(seed) => seed,
}; };
@ -147,21 +175,10 @@ fn validate(
Err(ValidationError::InputLength(input_length, max_input_length)) Err(ValidationError::InputLength(input_length, max_input_length))
} else { } else {
// Return ValidGenerateRequest // Return ValidGenerateRequest
let GenerateParameters {
temperature,
repetition_penalty,
top_k,
top_p,
do_sample,
max_new_tokens,
stop: stop_sequences,
..
} = request.parameters;
let parameters = NextTokenChooserParameters { let parameters = NextTokenChooserParameters {
temperature, temperature,
repetition_penalty, repetition_penalty,
top_k: top_k as u32, top_k,
top_p, top_p,
do_sample, do_sample,
seed, seed,
@ -206,7 +223,7 @@ pub enum ValidationError {
TopP, TopP,
#[error("top_k must be strictly positive")] #[error("top_k must be strictly positive")]
TopK, TopK,
#[error("max_new_tokens must be <= {0}")] #[error("max_new_tokens must be strictly positive and <= {0}")]
MaxNewTokens(u32), MaxNewTokens(u32),
#[error("inputs must have less than {1} tokens. Given: {0}")] #[error("inputs must have less than {1} tokens. Given: {0}")]
InputLength(usize, usize), InputLength(usize, usize),

View File

@ -1,6 +1,6 @@
# BLOOM Inference Python gRPC Server # Text Generation Inference Python gRPC Server
A Python gRPC server for BLOOM Inference A Python gRPC server for Text Generation Inference
## Install ## Install

View File

@ -1,7 +1,7 @@
[tool.poetry] [tool.poetry]
name = "text-generation" name = "text-generation"
version = "0.1.0" version = "0.2.0"
description = "BLOOM Inference Python gRPC Server" description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"] authors = ["Olivier Dehaene <olivier@huggingface.co>"]
[tool.poetry.scripts] [tool.poetry.scripts]

View File

@ -140,8 +140,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
assert len(generations) == 1 assert len(generations) == 1
assert ( assert (
generations[0].generated_text.text generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
== "TestTestTestTestTestTestTestTestTestTest"
) )
assert generations[0].request_id == default_bloom_batch.requests[0].id assert generations[0].request_id == default_bloom_batch.requests[0].id
assert ( assert (
@ -187,8 +186,7 @@ def test_causal_lm_generate_token_completion_multi(
assert len(generations) == 1 assert len(generations) == 1
assert ( assert (
generations[0].generated_text.text generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
== "TestTestTestTestTestTestTestTestTestTest"
) )
assert ( assert (
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@ -283,8 +281,7 @@ def test_batch_concatenate(
assert len(generations) == 2 assert len(generations) == 2
assert ( assert (
generations[0].generated_text.text generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
== "TestTestTestTestTestTestTestTestTestTest"
) )
assert generations[0].request_id == default_bloom_batch.requests[0].id assert generations[0].request_id == default_bloom_batch.requests[0].id
assert ( assert (
@ -306,8 +303,7 @@ def test_batch_concatenate(
assert len(generations) == 1 assert len(generations) == 1
assert ( assert (
generations[0].generated_text.text generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
== "TestTestTestTestTestTestTestTestTestTest"
) )
assert ( assert (
generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id

View File

@ -9,6 +9,7 @@ from text_generation.utils import (
StopSequenceCriteria, StopSequenceCriteria,
StoppingCriteria, StoppingCriteria,
LocalEntryNotFoundError, LocalEntryNotFoundError,
FinishReason,
) )
@ -24,13 +25,13 @@ def test_stop_sequence_criteria():
def test_stopping_criteria(): def test_stopping_criteria():
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5) criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
assert criteria(65827, "/test") == (False, None) assert criteria(65827, "/test") == (False, None)
assert criteria(30, ";") == (True, "stop_sequence") assert criteria(30, ";") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)
def test_stopping_criteria_eos(): def test_stopping_criteria_eos():
criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5) criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
assert criteria(1, "") == (False, None) assert criteria(1, "") == (False, None)
assert criteria(0, "") == (True, "eos_token") assert criteria(0, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN)
def test_stopping_criteria_max(): def test_stopping_criteria_max():
@ -39,7 +40,7 @@ def test_stopping_criteria_max():
assert criteria(1, "") == (False, None) assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None) assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None) assert criteria(1, "") == (False, None)
assert criteria(1, "") == (True, "length") assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
def test_weight_hub_files(): def test_weight_hub_files():

View File

@ -13,7 +13,7 @@ app = typer.Typer()
@app.command() @app.command()
def serve( def serve(
model_name: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
sharded: bool = False, sharded: bool = False,
quantize: bool = False, quantize: bool = False,
@ -46,16 +46,16 @@ def serve(
os.getenv("MASTER_PORT", None) is not None os.getenv("MASTER_PORT", None) is not None
), "MASTER_PORT must be set when sharded is True" ), "MASTER_PORT must be set when sharded is True"
server.serve(model_name, revision, sharded, quantize, uds_path) server.serve(model_id, revision, sharded, quantize, uds_path)
@app.command() @app.command()
def download_weights( def download_weights(
model_name: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
extension: str = ".safetensors", extension: str = ".safetensors",
): ):
utils.download_weights(model_name, revision, extension) utils.download_weights(model_id, revision, extension)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -30,31 +30,31 @@ torch.backends.cudnn.allow_tf32 = True
def get_model( def get_model(
model_name: str, revision: Optional[str], sharded: bool, quantize: bool model_id: str, revision: Optional[str], sharded: bool, quantize: bool
) -> Model: ) -> Model:
config = AutoConfig.from_pretrained(model_name, revision=revision) config = AutoConfig.from_pretrained(model_id, revision=revision)
if config.model_type == "bloom": if config.model_type == "bloom":
if sharded: if sharded:
return BLOOMSharded(model_name, revision, quantize=quantize) return BLOOMSharded(model_id, revision, quantize=quantize)
else: else:
return BLOOM(model_name, revision, quantize=quantize) return BLOOM(model_id, revision, quantize=quantize)
elif config.model_type == "gpt_neox": elif config.model_type == "gpt_neox":
if sharded: if sharded:
return GPTNeoxSharded(model_name, revision, quantize=quantize) return GPTNeoxSharded(model_id, revision, quantize=quantize)
else: else:
return GPTNeox(model_name, revision, quantize=quantize) return GPTNeox(model_id, revision, quantize=quantize)
elif model_name.startswith("facebook/galactica"): elif model_id.startswith("facebook/galactica"):
if sharded: if sharded:
return GalacticaSharded(model_name, revision, quantize=quantize) return GalacticaSharded(model_id, revision, quantize=quantize)
else: else:
return Galactica(model_name, revision, quantize=quantize) return Galactica(model_id, revision, quantize=quantize)
elif "santacoder" in model_name: elif "santacoder" in model_id:
return SantaCoder(model_name, revision, quantize) return SantaCoder(model_id, revision, quantize)
else: else:
if sharded: if sharded:
raise ValueError("sharded is not supported for AutoModel") raise ValueError("sharded is not supported for AutoModel")
try: try:
return CausalLM(model_name, revision, quantize=quantize) return CausalLM(model_id, revision, quantize=quantize)
except Exception: except Exception:
return Seq2SeqLM(model_name, revision, quantize=quantize) return Seq2SeqLM(model_id, revision, quantize=quantize)

View File

@ -57,10 +57,10 @@ class BLOOM(CausalLM):
class BLOOMSharded(BLOOM): class BLOOMSharded(BLOOM):
def __init__( def __init__(
self, model_name: str, revision: Optional[str] = None, quantize: bool = False self, model_id: str, revision: Optional[str] = None, quantize: bool = False
): ):
if not model_name.startswith("bigscience/bloom"): if not model_id.startswith("bigscience/bloom"):
raise ValueError(f"Model {model_name} is not supported") raise ValueError(f"Model {model_id} is not supported")
self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0 self.master = self.rank == 0
@ -72,22 +72,20 @@ class BLOOMSharded(BLOOM):
dtype = torch.float32 dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_name, revision=revision, padding_side="left" model_id, revision=revision, padding_side="left"
) )
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
model_name, revision=revision, slow_but_exact=False, tp_parallel=True model_id, revision=revision, slow_but_exact=False, tp_parallel=True
) )
config.pad_token_id = 3 config.pad_token_id = 3
# Only download weights for small models # Only download weights for small models
if self.master and model_name == "bigscience/bloom-560m": if self.master and model_id == "bigscience/bloom-560m":
download_weights(model_name, revision=revision, extension=".safetensors") download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files( filenames = weight_files(model_id, revision=revision, extension=".safetensors")
model_name, revision=revision, extension=".safetensors"
)
if not filenames: if not filenames:
raise ValueError("No safetensors weights found") raise ValueError("No safetensors weights found")

View File

@ -232,7 +232,7 @@ class CausalLMBatch(Batch):
class CausalLM(Model): class CausalLM(Model):
def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False): def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@ -244,10 +244,10 @@ class CausalLM(Model):
dtype = torch.float32 dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_name, revision=revision, padding_side="left" model_id, revision=revision, padding_side="left"
) )
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
model_name, model_id,
revision=revision, revision=revision,
torch_dtype=dtype, torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None, device_map="auto" if torch.cuda.is_available() else None,

View File

@ -149,10 +149,10 @@ class Galactica(CausalLM):
class GalacticaSharded(Galactica): class GalacticaSharded(Galactica):
def __init__( def __init__(
self, model_name: str, revision: Optional[str] = None, quantize: bool = False self, model_id: str, revision: Optional[str] = None, quantize: bool = False
): ):
if not model_name.startswith("facebook/galactica"): if not model_id.startswith("facebook/galactica"):
raise ValueError(f"Model {model_name} is not supported") raise ValueError(f"Model {model_id} is not supported")
self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0 self.master = self.rank == 0
@ -164,22 +164,20 @@ class GalacticaSharded(Galactica):
dtype = torch.float32 dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_name, revision=revision, padding_side="left" model_id, revision=revision, padding_side="left"
) )
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
model_name, revision=revision, tp_parallel=True model_id, revision=revision, tp_parallel=True
) )
tokenizer.pad_token_id = config.pad_token_id tokenizer.pad_token_id = config.pad_token_id
# Only download weights for small models # Only download weights for small models
if self.master and model_name == "facebook/galactica-125m": if self.master and model_id == "facebook/galactica-125m":
download_weights(model_name, revision=revision, extension=".safetensors") download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files( filenames = weight_files(model_id, revision=revision, extension=".safetensors")
model_name, revision=revision, extension=".safetensors"
)
if not filenames: if not filenames:
raise ValueError("No safetensors weights found") raise ValueError("No safetensors weights found")

View File

@ -49,7 +49,7 @@ class GPTNeox(CausalLM):
class GPTNeoxSharded(GPTNeox): class GPTNeoxSharded(GPTNeox):
def __init__( def __init__(
self, model_name: str, revision: Optional[str] = None, quantize: bool = False self, model_id: str, revision: Optional[str] = None, quantize: bool = False
): ):
self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0 self.master = self.rank == 0
@ -61,22 +61,20 @@ class GPTNeoxSharded(GPTNeox):
dtype = torch.float32 dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_name, revision=revision, padding_side="left" model_id, revision=revision, padding_side="left"
) )
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
model_name, revision=revision, tp_parallel=True model_id, revision=revision, tp_parallel=True
) )
# Only master download weights # Only master download weights
if self.master: if self.master:
download_weights(model_name, revision=revision, extension=".safetensors") download_weights(model_id, revision=revision, extension=".safetensors")
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files( filenames = weight_files(model_id, revision=revision, extension=".safetensors")
model_name, revision=revision, extension=".safetensors"
)
if not filenames: if not filenames:
raise ValueError("No safetensors weights found") raise ValueError("No safetensors weights found")

View File

@ -14,7 +14,7 @@ EOD = "<|endoftext|>"
class SantaCoder(CausalLM): class SantaCoder(CausalLM):
def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False): def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@ -26,7 +26,7 @@ class SantaCoder(CausalLM):
dtype = torch.float32 dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_name, revision=revision, padding_side="left" model_id, revision=revision, padding_side="left"
) )
tokenizer.add_special_tokens( tokenizer.add_special_tokens(
{ {
@ -43,7 +43,7 @@ class SantaCoder(CausalLM):
self.model = ( self.model = (
AutoModelForCausalLM.from_pretrained( AutoModelForCausalLM.from_pretrained(
model_name, model_id,
revision=revision, revision=revision,
torch_dtype=dtype, torch_dtype=dtype,
load_in_8bit=quantize, load_in_8bit=quantize,

View File

@ -289,7 +289,7 @@ class Seq2SeqLMBatch(Batch):
class Seq2SeqLM(Model): class Seq2SeqLM(Model):
def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False): def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@ -301,14 +301,14 @@ class Seq2SeqLM(Model):
dtype = torch.float32 dtype = torch.float32
self.model = AutoModelForSeq2SeqLM.from_pretrained( self.model = AutoModelForSeq2SeqLM.from_pretrained(
model_name, model_id,
revision=revision, revision=revision,
torch_dtype=dtype, torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None, device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize, load_in_8bit=quantize,
).eval() ).eval()
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_name, revision=revision, padding_side="left" model_id, revision=revision, padding_side="left"
) )
tokenizer.bos_token_id = self.model.config.decoder_start_token_id tokenizer.bos_token_id = self.model.config.decoder_start_token_id

View File

@ -7,6 +7,7 @@ from typing import List, Optional
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from text_generation.pb import generate_pb2 from text_generation.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
class Batch(ABC): class Batch(ABC):
@ -38,7 +39,7 @@ class Batch(ABC):
class GeneratedText: class GeneratedText:
text: str text: str
generated_tokens: int generated_tokens: int
finish_reason: str finish_reason: FinishReason
seed: Optional[int] seed: Optional[int]
def to_pb(self) -> generate_pb2.GeneratedText: def to_pb(self) -> generate_pb2.GeneratedText:

View File

@ -66,14 +66,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def serve( def serve(
model_name: str, model_id: str,
revision: Optional[str], revision: Optional[str],
sharded: bool, sharded: bool,
quantize: bool, quantize: bool,
uds_path: Path, uds_path: Path,
): ):
async def serve_inner( async def serve_inner(
model_name: str, model_id: str,
revision: Optional[str], revision: Optional[str],
sharded: bool = False, sharded: bool = False,
quantize: bool = False, quantize: bool = False,
@ -89,7 +89,7 @@ def serve(
local_url = unix_socket_template.format(uds_path, 0) local_url = unix_socket_template.format(uds_path, 0)
server_urls = [local_url] server_urls = [local_url]
model = get_model(model_name, revision, sharded, quantize) model = get_model(model_id, revision, sharded, quantize)
server = aio.server(interceptors=[ExceptionInterceptor()]) server = aio.server(interceptors=[ExceptionInterceptor()])
generate_pb2_grpc.add_TextGenerationServiceServicer_to_server( generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
@ -109,4 +109,4 @@ def serve(
logger.info("Signal received. Shutting down") logger.info("Signal received. Shutting down")
await server.stop(0) await server.stop(0)
asyncio.run(serve_inner(model_name, revision, sharded, quantize)) asyncio.run(serve_inner(model_id, revision, sharded, quantize))

View File

@ -24,9 +24,11 @@ from transformers.generation.logits_process import (
) )
from text_generation.pb import generate_pb2 from text_generation.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None) WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
class Sampling: class Sampling:
def __init__(self, seed: int, device: str = "cpu"): def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device) self.generator = torch.Generator(device)
@ -129,15 +131,15 @@ class StoppingCriteria:
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]: def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1 self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens: if self.current_tokens >= self.max_new_tokens:
return True, "length" return True, FinishReason.FINISH_REASON_LENGTH
if last_token == self.eos_token_id: if last_token == self.eos_token_id:
return True, "eos_token" return True, FinishReason.FINISH_REASON_EOS_TOKEN
self.current_output += last_output self.current_output += last_output
for stop_sequence_criteria in self.stop_sequence_criterias: for stop_sequence_criteria in self.stop_sequence_criterias:
if stop_sequence_criteria(self.current_output): if stop_sequence_criteria(self.current_output):
return True, "stop_sequence" return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
return False, None return False, None
@ -180,20 +182,20 @@ def initialize_torch_distributed():
return torch.distributed.distributed_c10d._get_default_group(), rank, world_size return torch.distributed.distributed_c10d._get_default_group(), rank, world_size
def weight_hub_files(model_name, revision=None, extension=".safetensors"): def weight_hub_files(model_id, revision=None, extension=".safetensors"):
"""Get the safetensors filenames on the hub""" """Get the safetensors filenames on the hub"""
api = HfApi() api = HfApi()
info = api.model_info(model_name, revision=revision) info = api.model_info(model_id, revision=revision)
filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)] filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
return filenames return filenames
def try_to_load_from_cache(model_name, revision, filename): def try_to_load_from_cache(model_id, revision, filename):
"""Try to load a file from the Hugging Face cache""" """Try to load a file from the Hugging Face cache"""
if revision is None: if revision is None:
revision = "main" revision = "main"
object_id = model_name.replace("/", "--") object_id = model_id.replace("/", "--")
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}" repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
if not repo_cache.is_dir(): if not repo_cache.is_dir():
@ -228,38 +230,38 @@ def try_to_load_from_cache(model_name, revision, filename):
return str(cached_file) if cached_file.is_file() else None return str(cached_file) if cached_file.is_file() else None
def weight_files(model_name, revision=None, extension=".safetensors"): def weight_files(model_id, revision=None, extension=".safetensors"):
"""Get the local safetensors filenames""" """Get the local safetensors filenames"""
if WEIGHTS_CACHE_OVERRIDE is not None: if WEIGHTS_CACHE_OVERRIDE is not None:
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}")) return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
filenames = weight_hub_files(model_name, revision, extension) filenames = weight_hub_files(model_id, revision, extension)
files = [] files = []
for filename in filenames: for filename in filenames:
cache_file = try_to_load_from_cache( cache_file = try_to_load_from_cache(
model_name, revision=revision, filename=filename model_id, revision=revision, filename=filename
) )
if cache_file is None: if cache_file is None:
raise LocalEntryNotFoundError( raise LocalEntryNotFoundError(
f"File {filename} of model {model_name} not found in " f"File {filename} of model {model_id} not found in "
f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. " f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
f"Please run `text-generation-server download-weights {model_name}` first." f"Please run `text-generation-server download-weights {model_id}` first."
) )
files.append(cache_file) files.append(cache_file)
return files return files
def download_weights(model_name, revision=None, extension=".safetensors"): def download_weights(model_id, revision=None, extension=".safetensors"):
"""Download the safetensors files from the hub""" """Download the safetensors files from the hub"""
if WEIGHTS_CACHE_OVERRIDE is not None: if WEIGHTS_CACHE_OVERRIDE is not None:
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}")) return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
filenames = weight_hub_files(model_name, revision, extension) filenames = weight_hub_files(model_id, revision, extension)
download_function = partial( download_function = partial(
hf_hub_download, hf_hub_download,
repo_id=model_name, repo_id=model_id,
local_files_only=False, local_files_only=False,
) )