feat(router): refactor API and add openAPI schemas (#53)

parent b1482d9048
commit 20c3c5940c
@@ -5,6 +5,8 @@ on:
   push:
     branches:
       - 'main'
+    tags:
+      - 'v*'
   pull_request:
     branches:
       - 'main'
@@ -43,6 +45,8 @@ jobs:
           ghcr.io/huggingface/text-generation-inference
           registry.internal.huggingface.tech/api-inference/community/text-generation-inference
         tags: |
+          type=semver,pattern={{version}}
+          type=semver,pattern={{major}}.{{minor}}
          type=raw,value=latest,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}
          type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}
     - name: Build and push Docker image
@@ -83,9 +83,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
 
 [[package]]
 name = "axum"
-version = "0.5.17"
+version = "0.6.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43"
+checksum = "e5694b64066a2459918d8074c2ce0d5a88f409431994c2356617c8ae0c4721fc"
 dependencies = [
  "async-trait",
  "axum-core",
@@ -101,8 +101,10 @@ dependencies = [
  "mime",
  "percent-encoding",
  "pin-project-lite",
+ "rustversion",
  "serde",
  "serde_json",
+ "serde_path_to_error",
  "serde_urlencoded",
  "sync_wrapper",
  "tokio",
@@ -114,9 +116,9 @@ dependencies = [
 
 [[package]]
 name = "axum-core"
-version = "0.2.9"
+version = "0.3.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "37e5939e02c56fecd5c017c37df4238c0a839fa76b7f97acdd7efb804fd181cc"
+checksum = "1cae3e661676ffbacb30f1a824089a8c9150e71017f7e1e38f2aa32009188d34"
 dependencies = [
  "async-trait",
  "bytes",
@@ -124,6 +126,7 @@ dependencies = [
  "http",
  "http-body",
  "mime",
+ "rustversion",
  "tower-layer",
  "tower-service",
 ]
@@ -207,7 +210,7 @@ dependencies = [
  "tar",
  "tempfile",
  "thiserror",
- "zip",
+ "zip 0.5.13",
  "zip-extensions",
 ]
 
@@ -465,6 +468,15 @@ dependencies = [
  "dirs-sys",
 ]
 
+[[package]]
+name = "dirs"
+version = "4.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059"
+dependencies = [
+ "dirs-sys",
+]
+
 [[package]]
 name = "dirs-sys"
 version = "0.3.7"
@@ -867,6 +879,7 @@ checksum = "10a35a97730320ffe8e2d410b5d3b69279b98d2c14bdb8b70ea89ecf7888d41e"
 dependencies = [
  "autocfg",
  "hashbrown",
+ "serde",
 ]
 
 [[package]]
@@ -999,9 +1012,9 @@ checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"
 
 [[package]]
 name = "matchit"
-version = "0.5.0"
+version = "0.7.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
+checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40"
 
 [[package]]
 name = "memchr"
@@ -1024,6 +1037,16 @@ version = "0.3.16"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d"
 
+[[package]]
+name = "mime_guess"
+version = "2.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef"
+dependencies = [
+ "mime",
+ "unicase",
+]
+
 [[package]]
 name = "minimal-lexical"
 version = "0.2.1"
@@ -1552,12 +1575,62 @@ dependencies = [
  "winreg",
 ]
 
+[[package]]
+name = "rust-embed"
+version = "6.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "283ffe2f866869428c92e0d61c2f35dfb4355293cdfdc48f49e895c15f1333d1"
+dependencies = [
+ "rust-embed-impl",
+ "rust-embed-utils",
+ "walkdir",
+]
+
+[[package]]
+name = "rust-embed-impl"
+version = "6.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "31ab23d42d71fb9be1b643fe6765d292c5e14d46912d13f3ae2815ca048ea04d"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "rust-embed-utils",
+ "shellexpand",
+ "syn",
+ "walkdir",
+]
+
+[[package]]
+name = "rust-embed-utils"
+version = "7.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c1669d81dfabd1b5f8e2856b8bbe146c6192b0ba22162edc738ac0a5de18f054"
+dependencies = [
+ "sha2",
+ "walkdir",
+]
+
+[[package]]
+name = "rustversion"
+version = "1.0.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5583e89e108996506031660fe09baa5011b9dd0341b89029313006d1fb508d70"
+
 [[package]]
 name = "ryu"
 version = "1.0.11"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09"
 
+[[package]]
+name = "same-file"
+version = "1.0.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
+dependencies = [
+ "winapi-util",
+]
+
 [[package]]
 name = "schannel"
 version = "0.1.20"
@@ -1628,6 +1701,15 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "serde_path_to_error"
+version = "0.1.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "26b04f22b563c91331a10074bda3dd5492e3cc39d56bd557e91c0af42b6c7341"
+dependencies = [
+ "serde",
+]
+
 [[package]]
 name = "serde_urlencoded"
 version = "0.7.1"
@@ -1660,6 +1742,15 @@ dependencies = [
  "lazy_static",
 ]
 
+[[package]]
+name = "shellexpand"
+version = "2.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7ccc8076840c4da029af4f87e4e8daeb0fca6b87bbb02e10cb60b791450e11e4"
+dependencies = [
+ "dirs 4.0.0",
+]
+
 [[package]]
 name = "signal-hook-registry"
 version = "1.4.0"
@@ -1797,7 +1888,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-client"
-version = "0.1.0"
+version = "0.2.0"
 dependencies = [
  "futures",
  "prost",
@@ -1812,7 +1903,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-launcher"
-version = "0.1.0"
+version = "0.2.0"
 dependencies = [
  "clap 4.0.22",
  "ctrlc",
@@ -1827,7 +1918,7 @@ dependencies = [
 
 [[package]]
 name = "text-generation-router"
-version = "0.1.0"
+version = "0.2.0"
 dependencies = [
  "async-stream",
  "axum",
@@ -1845,6 +1936,8 @@ dependencies = [
  "tokio-stream",
  "tracing",
  "tracing-subscriber",
+ "utoipa",
+ "utoipa-swagger-ui",
 ]
 
 [[package]]
@@ -1921,7 +2014,7 @@ dependencies = [
  "cached-path",
  "clap 2.34.0",
  "derive_builder",
- "dirs",
+ "dirs 3.0.2",
  "esaxx-rs",
  "getrandom",
  "indicatif 0.15.0",
@@ -2234,6 +2327,15 @@ version = "1.15.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
 
+[[package]]
+name = "unicase"
+version = "2.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6"
+dependencies = [
+ "version_check",
+]
+
 [[package]]
 name = "unicode-bidi"
 version = "0.3.8"
@@ -2293,6 +2395,46 @@ dependencies = [
  "percent-encoding",
 ]
 
+[[package]]
+name = "utoipa"
+version = "3.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3920fa753064b1be7842bea26175ffa0dfc4a8f30bcb52b8ff03fddf8889914c"
+dependencies = [
+ "indexmap",
+ "serde",
+ "serde_json",
+ "utoipa-gen",
+]
+
+[[package]]
+name = "utoipa-gen"
+version = "3.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "720298fac6efca20df9e457e67a1eab41a20d1c3101380b5c4dca1ca60ae0062"
+dependencies = [
+ "proc-macro-error",
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "utoipa-swagger-ui"
+version = "3.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ae3d4f4da6408f0f20ff58196ed619c94306ab32635aeca3d3fa0768c0bd0de2"
+dependencies = [
+ "axum",
+ "mime_guess",
+ "regex",
+ "rust-embed",
+ "serde",
+ "serde_json",
+ "utoipa",
+ "zip 0.6.4",
+]
+
 [[package]]
 name = "valuable"
 version = "0.1.0"
@@ -2317,6 +2459,17 @@ version = "0.9.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
 
+[[package]]
+name = "walkdir"
+version = "2.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56"
+dependencies = [
+ "same-file",
+ "winapi",
+ "winapi-util",
+]
+
 [[package]]
 name = "want"
 version = "0.3.0"
@@ -2589,11 +2742,23 @@ dependencies = [
  "time",
 ]
 
+[[package]]
+name = "zip"
+version = "0.6.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0445d0fbc924bb93539b4316c11afb121ea39296f99a3c4c9edad09e3658cdef"
+dependencies = [
+ "byteorder",
+ "crc32fast",
+ "crossbeam-utils",
+ "flate2",
+]
+
 [[package]]
 name = "zip-extensions"
 version = "0.6.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a64c3c977bc3434ce2d4bcea8ad3c644672de0f2c402b72b9171ca80a8885d14"
 dependencies = [
- "zip",
+ "zip 0.5.13",
 ]
@@ -4,9 +4,6 @@ members = [
   "router/client",
   "launcher"
 ]
-exclude = [
-  "server/safetensors",
-]
 
 [profile.release]
 debug = 1
Dockerfile (10 changes)
@@ -26,21 +26,18 @@ FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
 ENV LANG=C.UTF-8 \
     LC_ALL=C.UTF-8 \
     DEBIAN_FRONTEND=noninteractive \
-    MODEL_BASE_PATH=/data \
+    HUGGINGFACE_HUB_CACHE=/data \
     MODEL_ID=bigscience/bloom-560m \
     QUANTIZE=false \
-    NUM_GPUS=1 \
+    NUM_SHARD=1 \
     SAFETENSORS_FAST_GPU=1 \
     PORT=80 \
-    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
     NCCL_ASYNC_ERROR_HANDLING=1 \
     CUDA_HOME=/usr/local/cuda \
     LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
     CONDA_DEFAULT_ENV=text-generation \
     PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin
 
-SHELL ["/bin/bash", "-c"]
-
 RUN apt-get update && apt-get install -y unzip curl libssl-dev && rm -rf /var/lib/apt/lists/*
 
 RUN cd ~ && \
@@ -71,4 +68,5 @@ COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/loca
 # Install launcher
 COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher
 
-CMD HUGGINGFACE_HUB_CACHE=$MODEL_BASE_PATH text-generation-launcher --num-shard $NUM_GPUS --model-name $MODEL_ID --json-output
+ENTRYPOINT ["text-generation-launcher"]
+CMD ["--json-output"]
Makefile (14 changes)
@@ -15,17 +15,23 @@ server-dev:
 router-dev:
 	cd router && cargo run
 
+integration-tests: install-router install-launcher
+	cargo test
+
+python-tests:
+	cd server && pytest tests
+
 run-bloom-560m:
-	text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2
+	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2
 
 run-bloom-560m-quantize:
-	text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2 --quantize
+	text-generation-launcher --model-id bigscience/bloom-560m --num-shard 2 --quantize
 
 download-bloom:
 	text-generation-server download-weights bigscience/bloom
 
 run-bloom:
-	text-generation-launcher --model-name bigscience/bloom --num-shard 8
+	text-generation-launcher --model-id bigscience/bloom --num-shard 8
 
 run-bloom-quantize:
-	text-generation-launcher --model-name bigscience/bloom --num-shard 8 --quantize
+	text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize
README.md (113 changes)
@@ -1,16 +1,43 @@
+<div align="center">
+
 # Text Generation Inference
 
-<div align="center">
+<a href="https://github.com/huggingface/text-generation-inference">
+  <img alt="GitHub Repo stars" src="https://img.shields.io/github/stars/huggingface/text-generation-inference?style=social">
+</a>
+<a href="https://github.com/huggingface/text-generation-inference/blob/main/LICENSE">
+  <img alt="License" src="https://img.shields.io/github/license/huggingface/text-generation-inference">
+</a>
+<a href="https://huggingface.github.io/text-generation-inference">
+  <img alt="Swagger API documentation" src="https://img.shields.io/badge/API-Swagger-informational">
+</a>
 
 ![architecture](assets/architecture.jpg)
 
 </div>
 
-A Rust and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
-to power Bloom, BloomZ and MT0-XXL api-inference widgets.
+A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co)
+to power LLMs api-inference widgets.
+
+## Table of contents
+
+- [Features](#features)
+- [Officially Supported Models](#officially-supported-models)
+- [Get Started](#get-started)
+  - [Docker](#docker)
+  - [Local Install](#local-install)
+  - [OpenAPI](#api-documentation)
+  - [CUDA Kernels](#cuda-kernels)
+- [Run BLOOM](#run-bloom)
+  - [Download](#download)
+  - [Run](#run)
+  - [Quantization](#quantization)
+- [Develop](#develop)
+  - [Testing](#testing)
 
 ## Features
 
+- Token streaming using Server Side Events (SSE)
 - [Dynamic batching of incoming requests](https://github.com/huggingface/text-generation-inference/blob/main/router/src/batcher.rs#L88) for increased total throughput
 - Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
@@ -36,30 +63,63 @@ or
 
 `AutoModelForSeq2SeqLM.from_pretrained(<model>, device_map="auto")`
 
-## Load Tests for BLOOM
+## Get started
 
-See `k6/load_test.js`
+### Docker
 
-| | avg | min | med | max | p(90) | p(95) | RPS |
-|--------------------------------------------------------------|-----------|--------------|-----------|------------|-----------|-----------|----------|
-| [Original code](https://github.com/huggingface/transformers_bloom_parallel) | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 |
-| New batching logic | **5.44s** | **959.53ms** | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |
-
-## Install
+The easiest way of getting started is using the official Docker container:
 
 ```shell
-make install
+model=bigscience/bloom-560m
+num_shard=2
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run --gpus all -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --num-shard $num_shard
 ```
 
-## Run
-
-### BLOOM 560-m
+You can then query the model using either the `/generate` or `/generate_stream` routes:
 
 ```shell
+curl 127.0.0.1:8080/generate \
+    -X POST \
+    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+    -H 'Content-Type: application/json'
+```
+
+```shell
+curl 127.0.0.1:8080/generate_stream \
+    -X POST \
+    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+    -H 'Content-Type: application/json'
+```
+
+To use GPUs, you will need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html).
+
+### API documentation
+
+You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
+The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
+
+### Local install
+
+You can also opt to install `text-generation-inference` locally. You will need to have cargo and Python installed on your
+machine
+
+```shell
+BUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels
 make run-bloom-560m
 ```
 
-### BLOOM
+### CUDA Kernels
+
+The custom CUDA kernels are only tested on NVIDIA A100s. If you have any installation or runtime issues, you can remove
+the kernels by using the `BUILD_EXTENSIONS=False` environment variable.
+
+Be aware that the official Docker image has them enabled by default.
+
+## Run BLOOM
+
+### Download
 
 First you need to download the weights:
 
@@ -67,29 +127,30 @@ First you need to download the weights:
 make download-bloom
 ```
 
+### Run
+
 ```shell
 make run-bloom # Requires 8xA100 80GB
 ```
 
+### Quantization
+
 You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:
 
 ```shell
 make run-bloom-quantize # Requires 8xA100 40GB
 ```
 
-## Test
-
-```shell
-curl 127.0.0.1:3000/generate \
-    -v \
-    -X POST \
-    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
-    -H 'Content-Type: application/json'
-```
-
 ## Develop
 
 ```shell
 make server-dev
 make router-dev
+```
+
+## Testing
+
+```shell
+make python-tests
+make integration-tests
 ```
@@ -4,9 +4,9 @@ endpoint_name: bloom-inference
 model: azureml:bloom:1
 model_mount_path: /var/azureml-model
 environment_variables:
-  MODEL_BASE_PATH: /var/azureml-model/bloom
+  HUGGINGFACE_HUB_CACHE: /var/azureml-model/bloom
   MODEL_ID: bigscience/bloom
-  NUM_GPUS: 8
+  NUM_SHARD: 8
 environment:
   image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.3.1
 inference_config:
Binary file not shown. Before: 132 KiB; after: 334 KiB.
@@ -0,0 +1,30 @@
+<html>
+<head>
+    <!-- Load the latest Swagger UI code and style from npm using unpkg.com -->
+    <script src="https://unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js"></script>
+    <link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3/swagger-ui.css"/>
+    <title>Text Generation Inference API</title>
+</head>
+<body>
+<div id="swagger-ui"></div> <!-- Div to hold the UI component -->
+<script>
+    window.onload = function () {
+        // Begin Swagger UI call region
+        const ui = SwaggerUIBundle({
+            url: "openapi.json", //Location of Open API spec in the repo
+            dom_id: '#swagger-ui',
+            deepLinking: true,
+            supportedSubmitMethods: [],
+            presets: [
+                SwaggerUIBundle.presets.apis,
+                SwaggerUIBundle.SwaggerUIStandalonePreset
+            ],
+            plugins: [
+                SwaggerUIBundle.plugins.DownloadUrl
+            ],
+        })
+        window.ui = ui
+    }
+</script>
+</body>
+</html>
@@ -0,0 +1,446 @@
+{
+  "openapi": "3.0.3",
+  "info": {
+    "title": "Text Generation Inference",
+    "description": "Text Generation Webserver",
+    "contact": {
+      "name": "Olivier Dehaene",
+      "email": ""
+    },
+    "license": {
+      "name": "Apache 2.0",
+      "url": "https://www.apache.org/licenses/LICENSE-2.0"
+    },
+    "version": "0.2.0"
+  },
+  "paths": {
+    "/generate": {
+      "post": {
+        "tags": [
+          "Text Generation Inference"
+        ],
+        "summary": "Generate tokens",
+        "description": "Generate tokens",
+        "operationId": "generate",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/GenerateRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Generated Text",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "type": "array",
+                  "items": {
+                    "$ref": "#/components/schemas/GenerateResponse"
+                  }
+                }
+              }
+            }
+          },
+          "422": {
+            "description": "Input validation error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "type": "array",
+                  "items": {
+                    "$ref": "#/components/schemas/ErrorResponse"
+                  }
+                },
+                "example": {
+                  "error": "Input validation error"
+                }
+              }
+            }
+          },
+          "424": {
+            "description": "Generation Error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "type": "array",
+                  "items": {
+                    "$ref": "#/components/schemas/ErrorResponse"
+                  }
+                },
+                "example": {
+                  "error": "Request failed during generation"
+                }
+              }
+            }
+          },
+          "429": {
+            "description": "Model is overloaded",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "type": "array",
+                  "items": {
+                    "$ref": "#/components/schemas/ErrorResponse"
+                  }
+                },
+                "example": {
+                  "error": "Model is overloaded"
+                }
+              }
+            }
+          },
+          "500": {
+            "description": "Incomplete generation",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "type": "array",
+                  "items": {
+                    "$ref": "#/components/schemas/ErrorResponse"
+                  }
+                },
+                "example": {
+                  "error": "Incomplete generation"
+                }
+              }
+            }
+          }
+        },
+        "deprecated": false
+      }
+    },
+    "/generate_stream": {
+      "post": {
+        "tags": [
+          "Text Generation Inference"
+        ],
+        "summary": "Generate a stream of token using Server Side Events",
+        "description": "Generate a stream of token using Server Side Events",
+        "operationId": "generate_stream",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/GenerateRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Generated Text",
+            "content": {
+              "text/event-stream ": {
+                "schema": {
+                  "type": "array",
+                  "items": {
+                    "$ref": "#/components/schemas/StreamResponse"
+                  }
+                }
+              }
+            }
+          },
+          "422": {
+            "description": "Input validation error",
+            "content": {
+              "text/event-stream ": {
+                "schema": {
+                  "type": "array",
+                  "items": {
+                    "$ref": "#/components/schemas/ErrorResponse"
+                  }
+                },
+                "example": {
+                  "error": "Input validation error"
+                }
+              }
+            }
+          },
+          "424": {
+            "description": "Generation Error",
+            "content": {
+              "text/event-stream ": {
+                "schema": {
+                  "type": "array",
+                  "items": {
+                    "$ref": "#/components/schemas/ErrorResponse"
+                  }
+                },
+                "example": {
+                  "error": "Request failed during generation"
+                }
+              }
+            }
+          },
+          "429": {
+            "description": "Model is overloaded",
+            "content": {
+              "text/event-stream ": {
+                "schema": {
+                  "type": "array",
+                  "items": {
+                    "$ref": "#/components/schemas/ErrorResponse"
+                  }
+                },
+                "example": {
+                  "error": "Model is overloaded"
+                }
+              }
+            }
+          },
+          "500": {
+            "description": "Incomplete generation",
+            "content": {
+              "text/event-stream ": {
+                "schema": {
+                  "type": "array",
+                  "items": {
+                    "$ref": "#/components/schemas/ErrorResponse"
+                  }
+                },
+                "example": {
+                  "error": "Incomplete generation"
+                }
+              }
+            }
+          }
+        },
+        "deprecated": false
+      }
+    }
+  },
+  "components": {
+    "schemas": {
+      "Details": {
+        "type": "object",
+        "required": [
+          "finish_reason",
+          "generated_tokens"
+        ],
+        "properties": {
+          "finish_reason": {
+            "$ref": "#/components/schemas/FinishReason"
+          },
+          "generated_tokens": {
+            "type": "integer",
+            "format": "int32",
+            "example": 1
+          },
+          "prefill": {
+            "type": "array",
+            "items": {
+              "$ref": "#/components/schemas/Token"
+            }
+          },
+          "seed": {
+            "type": "integer",
+            "format": "int64",
+            "example": 42
+          },
+          "tokens": {
+            "type": "array",
+            "items": {
+              "$ref": "#/components/schemas/Token"
+            }
+          }
+        }
+      },
+      "ErrorResponse": {
+        "type": "object",
+        "required": [
+          "error"
+        ],
+        "properties": {
+          "error": {
+            "type": "string"
+          }
+        }
+      },
+      "FinishReason": {
+        "type": "string",
+        "enum": [
+          "length",
+          "eos_token",
+          "stop_sequence"
+        ]
+      },
+      "GenerateParameters": {
+        "type": "object",
+        "properties": {
+          "details": {
+            "type": "boolean",
+            "default": "true"
+          },
+          "do_sample": {
+            "type": "boolean",
+            "default": "false",
+            "example": true
+          },
+          "max_new_tokens": {
+            "type": "integer",
+            "format": "int32",
+            "default": "20",
+            "exclusiveMaximum": 512.0,
+            "exclusiveMinimum": 0.0
+          },
+          "repetition_penalty": {
+            "type": "number",
+            "format": "float",
+            "default": "null",
+            "example": 1.03,
+            "nullable": true,
+            "exclusiveMinimum": 0.0
+          },
+          "seed": {
+            "type": "integer",
+            "format": "int64"
+          },
+          "stop": {
+            "type": "array",
+            "items": {
+              "type": "string"
+            },
+            "example": [
+              "photographer"
+            ],
+            "maxItems": 4
+          },
+          "temperature": {
+            "type": "number",
+            "format": "float",
+            "default": "null",
+            "example": 0.5,
+            "nullable": true,
+            "exclusiveMinimum": 0.0
+          },
+          "top_k": {
+            "type": "integer",
+            "format": "int32",
+            "default": "null",
+            "example": 10,
+            "nullable": true,
+            "exclusiveMinimum": 0.0
+          },
+          "top_p": {
+            "type": "number",
+            "format": "float",
+            "default": "null",
+            "example": 0.95,
+            "nullable": true,
+            "maximum": 1.0,
+            "exclusiveMinimum": 0.0
+          }
+        }
+      },
+      "GenerateRequest": {
+        "type": "object",
+        "required": [
+          "inputs"
+        ],
+        "properties": {
+          "inputs": {
+            "type": "string",
+            "example": "My name is Olivier and I"
+          },
+          "parameters": {
+            "$ref": "#/components/schemas/GenerateParameters"
+          }
+        }
+      },
+      "GenerateResponse": {
+        "type": "object",
+        "required": [
+          "generated_text"
+        ],
+        "properties": {
+          "details": {
+            "$ref": "#/components/schemas/Details"
+          },
+          "generated_text": {
+            "type": "string",
+            "example": "test"
+          }
+        }
+      },
+      "StreamDetails": {
+        "type": "object",
+        "required": [
+          "finish_reason",
+          "generated_tokens"
+        ],
+        "properties": {
+          "finish_reason": {
+            "$ref": "#/components/schemas/FinishReason"
+          },
+          "generated_tokens": {
+            "type": "integer",
+            "format": "int32",
+            "example": 1
+          },
+          "seed": {
+            "type": "integer",
+            "format": "int64",
+            "example": 42
+          }
+        }
+      },
+      "StreamResponse": {
+        "type": "object",
+        "required": [
+          "token"
+        ],
+        "properties": {
+          "details": {
+            "$ref": "#/components/schemas/StreamDetails"
+          },
+          "generated_text": {
+            "type": "string",
+            "default": "null",
+            "example": "test",
+            "nullable": true
+          },
+          "token": {
+            "$ref": "#/components/schemas/Token"
+          }
+        }
+      },
+      "Token": {
+        "type": "object",
+        "required": [
+          "id",
+          "text",
+          "logprob"
+        ],
+        "properties": {
+          "id": {
+            "type": "integer",
+            "format": "int32",
+            "example": 0
+          },
+          "logprob": {
+            "type": "number",
+            "format": "float",
+            "example": -0.34,
+            "nullable": true
+          },
+          "text": {
+            "type": "string",
+            "example": "test"
+          }
+        }
+      }
+    }
+  },
+  "tags": [
+    {
+      "name": "Text Generation Inference",
+      "description": "Hugging Face Text Generation Inference API"
+    }
+  ]
+}
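Editor's note: the `docs/openapi.json` spec above documents the `/generate` route end to end, so a client in any language can POST a `GenerateRequest` against it. A minimal sketch in Rust, mirroring the README's curl examples; `reqwest` (with the `blocking` and `json` features) and `serde_json` are assumed client-side dependencies, not part of this repository, and the server address is illustrative:

```rust
use serde_json::json;

fn main() -> Result<(), reqwest::Error> {
    // Build a GenerateRequest payload as described by the schema above.
    let body = json!({
        "inputs": "Testing API",
        "parameters": { "max_new_tokens": 9 }
    });
    let response: serde_json::Value = reqwest::blocking::Client::new()
        .post("http://127.0.0.1:8080/generate")
        .json(&body) // serializes the request body as application/json
        .send()?
        .json()?; // on HTTP 200, a GenerateResponse ({"generated_text": ...})
    println!("{response}");
    Ok(())
}
```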
@@ -1,6 +1,6 @@
 [package]
 name = "text-generation-launcher"
-version = "0.1.0"
+version = "0.2.0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 description = "Text Generation Launcher"
@@ -19,7 +19,7 @@ use subprocess::{Popen, PopenConfig, PopenError, Redirection};
 #[clap(author, version, about, long_about = None)]
 struct Args {
     #[clap(default_value = "bigscience/bloom-560m", long, env)]
-    model_name: String,
+    model_id: String,
     #[clap(long, env)]
     revision: Option<String>,
     #[clap(long, env)]
@@ -49,7 +49,7 @@ struct Args {
 fn main() -> ExitCode {
     // Pattern match configuration
     let Args {
-        model_name,
+        model_id,
         revision,
         num_shard,
         quantize,
@@ -92,7 +92,7 @@ fn main() -> ExitCode {
 
     // Start shard processes
     for rank in 0..num_shard {
-        let model_name = model_name.clone();
+        let model_id = model_id.clone();
         let revision = revision.clone();
         let uds_path = shard_uds_path.clone();
         let master_addr = master_addr.clone();
@@ -101,7 +101,7 @@ fn main() -> ExitCode {
         let shutdown_sender = shutdown_sender.clone();
         thread::spawn(move || {
             shard_manager(
-                model_name,
+                model_id,
                 revision,
                 quantize,
                 uds_path,
@@ -167,7 +167,7 @@ fn main() -> ExitCode {
         "--master-shard-uds-path".to_string(),
         format!("{}-0", shard_uds_path),
         "--tokenizer-name".to_string(),
-        model_name,
+        model_id,
     ];
 
     if json_output {
@@ -256,7 +256,7 @@ enum ShardStatus {
 
 #[allow(clippy::too_many_arguments)]
 fn shard_manager(
-    model_name: String,
+    model_id: String,
     revision: Option<String>,
     quantize: bool,
     uds_path: String,
@@ -278,7 +278,7 @@ fn shard_manager(
     let mut shard_argv = vec![
         "text-generation-server".to_string(),
         "serve".to_string(),
-        model_name,
+        model_id,
         "--uds-path".to_string(),
         uds_path,
         "--logger-level".to_string(),
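Editor's note: the launcher diff above renames the `model_name` field to `model_id` throughout. With clap's derive API the CLI flag is derived from the field name, which is exactly what changes the public flag from `--model-name` to `--model-id` (matching the Makefile updates). A self-contained sketch of the mechanism, assuming clap 4 with the `derive` and `env` features as in the crate's dependency list:

```rust
use clap::Parser;

// Renaming the field is what renames the user-facing flag: clap derives
// --model-id from `model_id`, and `env` also reads the MODEL_ID variable.
#[derive(Parser, Debug)]
struct Args {
    #[clap(default_value = "bigscience/bloom-560m", long, env)]
    model_id: String,
}

fn main() {
    let args = Args::parse();
    println!("loading {}", args.model_id);
}
```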
@@ -1,123 +1,122 @@
-[
-  {
-    "details": {
-      "finish_reason": "length",
-      "generated_tokens": 20,
-      "prefill": [
-        [
-          10264,
-          "Test",
-          null
-        ],
-        [
-          8821,
-          " request",
-          -11.895094
-        ]
-      ],
-      "tokens": [
-        [
-          17,
-          ".",
-          -1.8267941
-        ],
-        [
-          1587,
-          "get",
-          -2.4674964
-        ],
-        [
-          11,
-          "(",
-          -1.9060438
-        ],
-        [
-          5,
-          "\"",
-          -1.2279553
-        ],
-        [
-          4899,
-          "action",
-          -4.170306
-        ],
-        [
-          5,
-          "\"",
-          -0.3247902
-        ],
-        [
-          12,
-          ")",
-          -1.0773602
-        ],
-        [
-          30,
-          ";",
-          -0.27640444
-        ],
-        [
-          837,
-          "\n ",
-          -1.6970599
-        ],
-        [
-          1320,
-          " if",
-          -1.4495552
-        ],
-        [
-          375,
-          " (",
-          -0.2360998
-        ],
-        [
-          4899,
-          "action",
-          -1.1916926
-        ],
-        [
-          3535,
-          " ==",
-          -0.8918663
-        ],
-        [
-          5109,
-          " null",
-          -0.39334255
-        ],
-        [
-          12,
-          ")",
-          -0.4321134
-        ],
-        [
-          731,
-          " {",
-          -0.17701954
-        ],
-        [
-          1260,
-          "\n ",
-          -0.07027287
-        ],
-        [
-          10519,
-          " throw",
-          -1.3915133
-        ],
-        [
-          2084,
-          " new",
-          -0.042013377
-        ],
-        [
-          150858,
-          " RuntimeException",
-          -1.7330077
-        ]
-      ]
-    },
-    "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
-  }
-]
+{
+  "details": {
+    "finish_reason": "length",
+    "generated_tokens": 20,
+    "prefill": [
+      {
+        "id": 10264,
+        "logprob": null,
+        "text": "Test"
+      },
+      {
+        "id": 8821,
+        "logprob": -11.894989,
+        "text": " request"
+      }
+    ],
+    "seed": null,
+    "tokens": [
+      {
+        "id": 17,
+        "logprob": -1.8267672,
+        "text": "."
+      },
+      {
+        "id": 1587,
+        "logprob": -2.4674969,
+        "text": "get"
+      },
+      {
+        "id": 11,
+        "logprob": -1.906001,
+        "text": "("
+      },
+      {
+        "id": 5,
+        "logprob": -1.2279545,
+        "text": "\""
+      },
+      {
+        "id": 4899,
+        "logprob": -4.170299,
+        "text": "action"
+      },
+      {
+        "id": 5,
+        "logprob": -0.32478866,
+        "text": "\""
+      },
+      {
+        "id": 12,
+        "logprob": -1.0773665,
+        "text": ")"
+      },
+      {
+        "id": 30,
+        "logprob": -0.27640742,
+        "text": ";"
+      },
+      {
+        "id": 837,
+        "logprob": -1.6970354,
+        "text": "\n "
+      },
+      {
+        "id": 1320,
+        "logprob": -1.4495516,
+        "text": " if"
+      },
+      {
+        "id": 375,
+        "logprob": -0.23609057,
+        "text": " ("
+      },
+      {
+        "id": 4899,
+        "logprob": -1.1916996,
+        "text": "action"
+      },
+      {
+        "id": 3535,
+        "logprob": -0.8918753,
+        "text": " =="
+      },
+      {
+        "id": 5109,
+        "logprob": -0.3933342,
+        "text": " null"
+      },
+      {
+        "id": 12,
+        "logprob": -0.43212673,
+        "text": ")"
+      },
+      {
+        "id": 731,
+        "logprob": -0.17702064,
+        "text": " {"
+      },
+      {
+        "id": 1260,
+        "logprob": -0.07027565,
+        "text": "\n "
+      },
+      {
+        "id": 10519,
+        "logprob": -1.3915029,
+        "text": " throw"
+      },
+      {
+        "id": 2084,
+        "logprob": -0.04201372,
+        "text": " new"
+      },
+      {
+        "id": 150858,
+        "logprob": -1.7329919,
+        "text": " RuntimeException"
+      }
+    ]
+  },
+  "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
+}
@@ -9,11 +9,18 @@ use std::thread::sleep;
 use std::time::Duration;
 use subprocess::{Popen, PopenConfig, Redirection};
 
+#[derive(Deserialize)]
+pub struct Token {
+    id: u32,
+    text: String,
+    logprob: Option<f32>,
+}
+
 #[derive(Deserialize)]
 struct Details {
     finish_reason: String,
     generated_tokens: u32,
-    tokens: Vec<(u32, String, Option<f32>)>,
+    tokens: Vec<Token>,
 }
 
 #[derive(Deserialize)]
@@ -22,11 +29,11 @@ struct GeneratedText {
     details: Details,
 }
 
-fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
+fn start_launcher(model_id: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
     let argv = vec![
         "text-generation-launcher".to_string(),
-        "--model-name".to_string(),
-        model_name.clone(),
+        "--model-id".to_string(),
+        model_id.clone(),
         "--num-shard".to_string(),
         num_shard.to_string(),
         "--port".to_string(),
@@ -68,16 +75,16 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
 
     launcher.terminate().unwrap();
     launcher.wait().unwrap();
-    panic!("failed to launch {}", model_name)
+    panic!("failed to launch {}", model_id)
 }
 
 fn test_model(
-    model_name: String,
+    model_id: String,
     num_shard: usize,
     port: usize,
     master_port: usize,
 ) -> GeneratedText {
-    let mut launcher = start_launcher(model_name, num_shard, port, master_port);
+    let mut launcher = start_launcher(model_id, num_shard, port, master_port);
 
     let data = r#"
         {
@@ -109,8 +116,8 @@ fn read_json(name: &str) -> GeneratedText {
     let file = File::open(d).unwrap();
     let reader = BufReader::new(file);
 
-    let mut results: Vec<GeneratedText> = serde_json::from_reader(reader).unwrap();
-    results.pop().unwrap()
+    let result: GeneratedText = serde_json::from_reader(reader).unwrap();
+    result
 }
 
 fn compare_results(result: GeneratedText, expected: GeneratedText) {
@@ -127,13 +134,13 @@ fn compare_results(result: GeneratedText, expected: GeneratedText) {
         .into_iter()
         .zip(expected.details.tokens.into_iter())
     {
-        assert_eq!(token.0, expected_token.0);
-        assert_eq!(token.1, expected_token.1);
-        if let Some(logprob) = token.2 {
-            let expected_logprob = expected_token.2.unwrap();
+        assert_eq!(token.id, expected_token.id);
+        assert_eq!(token.text, expected_token.text);
+        if let Some(logprob) = token.logprob {
+            let expected_logprob = expected_token.logprob.unwrap();
             assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
         } else {
-            assert_eq!(token.2, expected_token.2);
+            assert_eq!(token.logprob, expected_token.logprob);
         }
     }
 }
@@ -1,118 +1,117 @@
-[
-  {
-    "details": {
-      "finish_reason": "length",
-      "generated_tokens": 20,
-      "prefill": [
-        [
-          0,
-          "<pad>",
-          null
-        ]
-      ],
-      "tokens": [
-        [
-          259,
-          "",
-          -1.3656927
-        ],
-        [
-          215100,
-          "\"\"\"",
-          -2.6551573
-        ],
-        [
-          46138,
-          "Test",
-          -1.8059857
-        ],
-        [
-          287,
-          "the",
-          -1.2102449
-        ],
-        [
-          259,
-          "",
-          -1.6057279
-        ],
-        [
-          49076,
-          "contents",
-          -3.6060903
-        ],
-        [
-          304,
-          "of",
-          -0.5270343
-        ],
-        [
-          287,
-          "the",
-          -0.62522805
-        ],
-        [
-          259,
-          "",
-          -1.4069618
-        ],
-        [
-          49076,
-          "contents",
-          -2.621994
-        ],
-        [
-          304,
-          "of",
-          -1.3172221
-        ],
-        [
-          287,
-          "the",
-          -0.3501925
-        ],
-        [
-          259,
-          "",
-          -0.7219573
-        ],
-        [
-          49076,
-          "contents",
-          -1.0494149
-        ],
-        [
-          260,
-          ".",
-          -1.0803378
-        ],
-        [
-          259,
-          "",
-          -0.32933083
-        ],
-        [
-          215100,
-          "\"\"\"",
-          -0.11268901
-        ],
-        [
-          2978,
-          "test",
-          -1.5846587
-        ],
-        [
-          290,
-          "_",
-          -0.49796978
-        ],
-        [
-          4125,
-          "test",
-          -2.0026445
-        ]
-      ]
-    },
-    "generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
-  }
-]
+{
+  "details": {
+    "finish_reason": "length",
+    "generated_tokens": 20,
+    "prefill": [
+      {
+        "id": 0,
+        "logprob": null,
+        "text": "<pad>"
+      }
+    ],
+    "seed": null,
+    "tokens": [
+      {
+        "id": 259,
+        "logprob": -1.3656927,
+        "text": ""
+      },
+      {
+        "id": 215100,
+        "logprob": -2.6551573,
+        "text": "\"\"\""
+      },
+      {
+        "id": 46138,
+        "logprob": -1.8059857,
+        "text": "Test"
+      },
+      {
+        "id": 287,
+        "logprob": -1.2102449,
+        "text": "the"
+      },
+      {
+        "id": 259,
+        "logprob": -1.6057279,
+        "text": ""
+      },
+      {
+        "id": 49076,
+        "logprob": -3.6060903,
+        "text": "contents"
+      },
+      {
+        "id": 304,
+        "logprob": -0.5270343,
+        "text": "of"
+      },
+      {
+        "id": 287,
+        "logprob": -0.62522805,
+        "text": "the"
+      },
+      {
+        "id": 259,
+        "logprob": -1.4069618,
+        "text": ""
+      },
+      {
+        "id": 49076,
+        "logprob": -2.621994,
+        "text": "contents"
+      },
+      {
+        "id": 304,
+        "logprob": -1.3172221,
+        "text": "of"
+      },
+      {
+        "id": 287,
+        "logprob": -0.3501925,
+        "text": "the"
+      },
+      {
+        "id": 259,
+        "logprob": -0.7219573,
+        "text": ""
+      },
+      {
+        "id": 49076,
+        "logprob": -1.0494149,
+        "text": "contents"
+      },
+      {
+        "id": 260,
+        "logprob": -1.0803378,
+        "text": "."
+      },
+      {
+        "id": 259,
+        "logprob": -0.32933083,
+        "text": ""
+      },
+      {
+        "id": 215100,
+        "logprob": -0.11268901,
+        "text": "\"\"\""
+      },
+      {
+        "id": 2978,
+        "logprob": -1.5846587,
+        "text": "test"
+      },
+      {
+        "id": 290,
+        "logprob": -0.49796978,
+        "text": "_"
+      },
+      {
+        "id": 4125,
+        "logprob": -2.0026445,
+        "text": "test"
+      }
+    ]
+  },
+  "generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
+}
@@ -71,13 +71,19 @@ message Batch {
     uint32 size = 3;
 }
 
+enum FinishReason {
+    FINISH_REASON_LENGTH = 0;
+    FINISH_REASON_EOS_TOKEN = 1;
+    FINISH_REASON_STOP_SEQUENCE = 2;
+}
+
 message GeneratedText {
     /// Output
     string text = 1;
     /// Number of generated tokens
     uint32 generated_tokens = 2;
     /// Finish reason
-    string finish_reason = 3;
+    FinishReason finish_reason = 3;
     /// Seed
     optional uint64 seed = 4;
 }
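Editor's note: the protobuf enum above replaces the free-form `finish_reason` string, while the OpenAPI schema in this same commit exposes the three variants as the lowercase strings `length`, `eos_token` and `stop_sequence`. A sketch of one plausible mapping between the two representations; the names and function are illustrative, not necessarily the router's actual code:

```rust
// Illustrative only: projecting the three protobuf variants onto the
// lowercase values listed in the FinishReason schema of docs/openapi.json.
#[derive(Clone, Copy, Debug, PartialEq)]
enum FinishReason {
    Length,
    EosToken,
    StopSequence,
}

impl FinishReason {
    fn as_schema_str(self) -> &'static str {
        match self {
            FinishReason::Length => "length",
            FinishReason::EosToken => "eos_token",
            FinishReason::StopSequence => "stop_sequence",
        }
    }
}

fn main() {
    assert_eq!(FinishReason::EosToken.as_schema_str(), "eos_token");
}
```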
@@ -1,6 +1,6 @@
 [package]
 name = "text-generation-router"
-version = "0.1.0"
+version = "0.2.0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 description = "Text Generation Webserver"
@@ -14,7 +14,7 @@ path = "src/main.rs"
 
 [dependencies]
 async-stream = "0.3.3"
-axum = { version = "0.5.16", features = ["json", "serde_json"] }
+axum = { version = "0.6.4", features = ["json"] }
 text-generation-client = { path = "client" }
 clap = { version = "4.0.15", features = ["derive", "env"] }
 futures = "0.3.24"
@@ -29,4 +29,6 @@ tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot"
 tokio-stream = "0.1.11"
 tracing = "0.1.36"
 tracing-subscriber = { version = "0.3.15", features = ["json"] }
+utoipa = { version = "3.0.1", features = ["axum_extras"] }
+utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
@@ -1,6 +1,6 @@
 [package]
 name = "text-generation-client"
-version = "0.1.0"
+version = "0.2.0"
 edition = "2021"
 
 [dependencies]
@@ -7,8 +7,8 @@ mod sharded_client;

 pub use client::Client;
 pub use pb::generate::v1::{
-    Batch, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens, Request,
-    StoppingCriteriaParameters,
+    Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
+    Request, StoppingCriteriaParameters,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
@@ -127,7 +127,7 @@ impl Infer {
             .into_iter()
             .zip(tokens.logprobs.into_iter())
             .zip(tokens.texts.into_iter())
-            .map(|((id, logprob), text)| Token(id, text, logprob))
+            .map(|((id, logprob), text)| Token { id, text, logprob })
             .collect();
     }
     // Push last token
@@ -282,11 +282,11 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
     }

     // Create last Token
-    let token = Token(
-        generation.token_id,
-        generation.token_text,
-        generation.token_logprob,
-    );
+    let token = Token {
+        id: generation.token_id,
+        text: generation.token_text,
+        logprob: generation.token_logprob,
+    };

     if let Some(generated_text) = generation.generated_text {
         // Remove entry as this is the last message
@@ -1,5 +1,4 @@
 /// Text Generation Inference Webserver
-
 mod infer;
 mod queue;
 pub mod server;
@@ -8,45 +7,55 @@ mod validation;
 use infer::Infer;
 use queue::{Entry, Queue};
 use serde::{Deserialize, Serialize};
+use utoipa::ToSchema;
 use validation::Validation;

-#[derive(Clone, Debug, Deserialize)]
+#[derive(Clone, Debug, Deserialize, ToSchema)]
 pub(crate) struct GenerateParameters {
-    #[serde(default = "default_temperature")]
-    pub temperature: f32,
-    #[serde(default = "default_repetition_penalty")]
-    pub repetition_penalty: f32,
-    #[serde(default = "default_top_k")]
-    pub top_k: i32,
-    #[serde(default = "default_top_p")]
-    pub top_p: f32,
+    #[serde(default)]
+    #[schema(
+        exclusive_minimum = 0.0,
+        nullable = true,
+        default = "null",
+        example = 0.5
+    )]
+    pub temperature: Option<f32>,
+    #[serde(default)]
+    #[schema(
+        exclusive_minimum = 0.0,
+        nullable = true,
+        default = "null",
+        example = 1.03
+    )]
+    pub repetition_penalty: Option<f32>,
+    #[serde(default)]
+    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
+    pub top_k: Option<i32>,
+    #[serde(default)]
+    #[schema(
+        exclusive_minimum = 0.0,
+        maximum = 1.0,
+        nullable = true,
+        default = "null",
+        example = 0.95
+    )]
+    pub top_p: Option<f32>,
     #[serde(default = "default_do_sample")]
+    #[schema(default = "false", example = true)]
     pub do_sample: bool,
     #[serde(default = "default_max_new_tokens")]
+    #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
     pub max_new_tokens: u32,
     #[serde(default)]
+    #[schema(inline, max_items = 4, example = json!(["photographer"]))]
     pub stop: Vec<String>,
     #[serde(default)]
+    #[schema(default = "true")]
     pub details: bool,
     #[serde(default)]
     pub seed: Option<u64>,
 }

-fn default_temperature() -> f32 {
-    1.0
-}
-
-fn default_repetition_penalty() -> f32 {
-    1.0
-}
-
-fn default_top_k() -> i32 {
-    0
-}
-
-fn default_top_p() -> f32 {
-    1.0
-}
-
 fn default_do_sample() -> bool {
     false
 }
@@ -57,10 +66,10 @@ fn default_max_new_tokens() -> u32 {

 fn default_parameters() -> GenerateParameters {
     GenerateParameters {
-        temperature: default_temperature(),
-        repetition_penalty: default_repetition_penalty(),
-        top_k: default_top_k(),
-        top_p: default_top_p(),
+        temperature: None,
+        repetition_penalty: None,
+        top_k: None,
+        top_p: None,
         do_sample: default_do_sample(),
         max_new_tokens: default_max_new_tokens(),
         stop: vec![],
@@ -69,42 +78,77 @@ fn default_parameters() -> GenerateParameters {
     }
 }

-#[derive(Clone, Debug, Deserialize)]
+#[derive(Clone, Debug, Deserialize, ToSchema)]
 pub(crate) struct GenerateRequest {
+    #[schema(example = "My name is Olivier and I")]
     pub inputs: String,
     #[serde(default = "default_parameters")]
     pub parameters: GenerateParameters,
 }

-#[derive(Debug, Serialize)]
-pub struct Token(u32, String, f32);
+#[derive(Debug, Serialize, ToSchema)]
+pub struct Token {
+    #[schema(example = 0)]
+    id: u32,
+    #[schema(example = "test")]
+    text: String,
+    #[schema(nullable = true, example = -0.34)]
+    logprob: f32,
+}

-#[derive(Serialize)]
+#[derive(Serialize, ToSchema)]
+#[serde(rename_all(serialize = "snake_case"))]
+pub(crate) enum FinishReason {
+    #[schema(rename = "length")]
+    Length,
+    #[serde(rename = "eos_token")]
+    #[schema(rename = "eos_token")]
+    EndOfSequenceToken,
+    #[schema(rename = "stop_sequence")]
+    StopSequence,
+}
+
+#[derive(Serialize, ToSchema)]
 pub(crate) struct Details {
-    pub finish_reason: String,
+    #[schema(example = "length")]
+    pub finish_reason: FinishReason,
+    #[schema(example = 1)]
     pub generated_tokens: u32,
+    #[schema(example = 42)]
     pub seed: Option<u64>,
-    #[serde(skip_serializing_if = "Option::is_none")]
     pub prefill: Option<Vec<Token>>,
-    #[serde(skip_serializing_if = "Option::is_none")]
     pub tokens: Option<Vec<Token>>,
 }

-#[derive(Serialize)]
+#[derive(Serialize, ToSchema)]
 pub(crate) struct GenerateResponse {
+    #[schema(example = "test")]
     pub generated_text: String,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub details: Option<Details>,
 }

-#[derive(Serialize)]
-pub(crate) struct StreamResponse {
-    pub token: Token,
-    pub generated_text: Option<String>,
-    pub details: Option<Details>,
+#[derive(Serialize, ToSchema)]
+pub(crate) struct StreamDetails {
+    #[schema(example = "length")]
+    pub finish_reason: FinishReason,
+    #[schema(example = 1)]
+    pub generated_tokens: u32,
+    #[schema(example = 42)]
+    pub seed: Option<u64>,
 }

-#[derive(Serialize)]
+#[derive(Serialize, ToSchema)]
+pub(crate) struct StreamResponse {
+    pub token: Token,
+    #[schema(nullable = true, default = "null", example = "test")]
+    pub generated_text: Option<String>,
+    #[schema(nullable = true, default = "null")]
+    pub details: Option<StreamDetails>,
+}
+
+#[derive(Serialize, ToSchema)]
 pub(crate) struct ErrorResponse {
+    #[schema(inline)]
     pub error: String,
 }
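Every sampling knob is now Option-al with #[serde(default)], so an omitted JSON field deserializes to None instead of a baked-in value, and defaulting moves into validation. A small sketch of that behavior, using an illustrative two-field subset of GenerateParameters (not the real struct):

use serde::Deserialize;

// Illustrative subset: all sampling knobs optional, mirroring the diff above.
#[derive(Debug, Deserialize)]
struct Params {
    #[serde(default)]
    temperature: Option<f32>,
    #[serde(default)]
    top_k: Option<i32>,
}

fn main() -> Result<(), serde_json::Error> {
    // "temperature" present, "top_k" omitted entirely.
    let params: Params = serde_json::from_str(r#"{ "temperature": 0.5 }"#)?;
    assert_eq!(params.temperature, Some(0.5));
    assert_eq!(params.top_k, None); // omitted => None, defaulted later in validation
    println!("{:?}", params);
    Ok(())
}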
@@ -1,8 +1,8 @@
 /// HTTP Server logic
 use crate::infer::{InferError, InferStreamResponse};
 use crate::{
-    Details, ErrorResponse, GenerateParameters, GenerateRequest, GenerateResponse, Infer,
-    StreamResponse, Validation,
+    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
+    Infer, StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, StatusCode};
@@ -19,6 +19,8 @@ use tokio::signal;
 use tokio::time::Instant;
 use tokio_stream::StreamExt;
 use tracing::instrument;
+use utoipa::OpenApi;
+use utoipa_swagger_ui::SwaggerUi;

 /// Health check method
 #[instrument(skip(infer))]
@@ -32,13 +34,13 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
         .generate(GenerateRequest {
             inputs: "liveness".to_string(),
             parameters: GenerateParameters {
-                temperature: 1.0,
-                repetition_penalty: 1.0,
-                top_k: 0,
-                top_p: 1.0,
+                temperature: None,
+                repetition_penalty: None,
+                top_k: None,
+                top_p: None,
                 do_sample: false,
                 max_new_tokens: 1,
-                stop: vec![],
+                stop: Vec::new(),
                 details: false,
                 seed: None,
             },
@@ -47,7 +49,24 @@
     Ok(())
 }

-/// Generate method
+/// Generate tokens
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/generate",
+    request_body = GenerateRequest,
+    responses(
+        (status = 200, description = "Generated Text", body = [GenerateResponse]),
+        (status = 424, description = "Generation Error", body = [ErrorResponse],
+            example = json!({"error": "Request failed during generation"})),
+        (status = 429, description = "Model is overloaded", body = [ErrorResponse],
+            example = json!({"error": "Model is overloaded"})),
+        (status = 422, description = "Input validation error", body = [ErrorResponse],
+            example = json!({"error": "Input validation error"})),
+        (status = 500, description = "Incomplete generation", body = [ErrorResponse],
+            example = json!({"error": "Incomplete generation"})),
+    )
+)]
 #[instrument(
     skip(infer),
     fields(
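The #[utoipa::path] annotation only documents the route; the handler behaves as before. A sketch of exercising it over HTTP — the host/port and the reqwest/tokio dependencies are assumptions, not part of the commit:

use serde_json::json;

// Sketch only: assumes the router listens on localhost:3000 and that
// `reqwest` (with the "json" feature) and `tokio` are available.
#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let body = json!({
        "inputs": "My name is Olivier and I",
        "parameters": { "max_new_tokens": 20, "details": true }
    });
    let response = reqwest::Client::new()
        .post("http://localhost:3000/generate")
        .json(&body)
        .send()
        .await?;
    // On success the body matches the GenerateResponse schema above.
    println!("{}", response.text().await?);
    Ok(())
}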
@@ -76,7 +95,7 @@ async fn generate(
     // Token details
     let details = match details {
         true => Some(Details {
-            finish_reason: response.generated_text.finish_reason,
+            finish_reason: FinishReason::from(response.generated_text.finish_reason),
             generated_tokens: response.generated_text.generated_tokens,
             prefill: Some(response.prefill),
             tokens: Some(response.tokens),
@@ -132,7 +151,29 @@ async fn generate(
     Ok((headers, Json(response)))
 }

-/// Generate stream method
+/// Generate a stream of token using Server Side Events
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/generate_stream",
+    request_body = GenerateRequest,
+    responses(
+        (status = 200, description = "Generated Text", body = [StreamResponse],
+            content_type="text/event-stream "),
+        (status = 424, description = "Generation Error", body = [ErrorResponse],
+            example = json!({"error": "Request failed during generation"}),
+            content_type="text/event-stream "),
+        (status = 429, description = "Model is overloaded", body = [ErrorResponse],
+            example = json!({"error": "Model is overloaded"}),
+            content_type="text/event-stream "),
+        (status = 422, description = "Input validation error", body = [ErrorResponse],
+            example = json!({"error": "Input validation error"}),
+            content_type="text/event-stream "),
+        (status = 500, description = "Incomplete generation", body = [ErrorResponse],
+            example = json!({"error": "Incomplete generation"}),
+            content_type="text/event-stream "),
+    )
+)]
 #[instrument(
     skip(infer),
     fields(
@@ -185,11 +226,9 @@ async fn generate_stream(
                 } => {
                     // Token details
                     let details = match details {
-                        true => Some(Details {
-                            finish_reason: generated_text.finish_reason,
+                        true => Some(StreamDetails {
+                            finish_reason: FinishReason::from(generated_text.finish_reason),
                             generated_tokens: generated_text.generated_tokens,
-                            prefill: None,
-                            tokens: None,
                             seed: generated_text.seed,
                         }),
                         false => None,
@@ -265,6 +304,39 @@ pub async fn run(
     validation_workers: usize,
     addr: SocketAddr,
 ) {
+    // OpenAPI documentation
+    #[derive(OpenApi)]
+    #[openapi(
+        paths(
+            generate,
+            generate_stream,
+        ),
+        components(
+            schemas(
+                GenerateRequest,
+                GenerateParameters,
+                Token,
+                GenerateResponse,
+                Details,
+                FinishReason,
+                StreamResponse,
+                StreamDetails,
+                ErrorResponse,
+            )
+        ),
+        tags(
+            (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
+        ),
+        info(
+            title = "Text Generation Inference",
+            license(
+                name = "Apache 2.0",
+                url = "https://www.apache.org/licenses/LICENSE-2.0"
+            )
+        )
+    )]
+    struct ApiDoc;
+
     // Create state
     let validation = Validation::new(validation_workers, tokenizer, max_input_length);
     let infer = Infer::new(
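Anything deriving utoipa::OpenApi can also emit its spec outside the server, which is handy for inspecting the generated document. A sketch under the assumption of a pared-down ApiDoc (the real derive also lists the paths and schemas above):

use utoipa::OpenApi;

// Pared-down stand-in for the ApiDoc above; illustrative only.
#[derive(OpenApi)]
#[openapi(info(title = "Text Generation Inference"))]
struct ApiDoc;

fn main() -> std::io::Result<()> {
    let spec = ApiDoc::openapi()
        .to_pretty_json()
        .expect("the OpenAPI document serializes to JSON");
    std::fs::write("openapi.json", spec)
}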
@@ -277,6 +349,7 @@ pub async fn run(

     // Create router
     let app = Router::new()
+        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
         .route("/", post(generate))
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
@@ -320,6 +393,17 @@ async fn shutdown_signal() {
     tracing::info!("signal received, starting graceful shutdown");
 }

+impl From<i32> for FinishReason {
+    fn from(finish_reason: i32) -> Self {
+        let finish_reason = text_generation_client::FinishReason::from_i32(finish_reason).unwrap();
+        match finish_reason {
+            text_generation_client::FinishReason::Length => FinishReason::Length,
+            text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
+            text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
+        }
+    }
+}
+
 /// Convert to Axum supported formats
 impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
     fn from(err: InferError) -> Self {
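prost represents proto enum fields as raw i32 on the wire, which is why the conversion above goes through from_i32 before mapping onto the HTTP-layer enum. A self-contained sketch of the same pattern, with both enums redeclared here purely for illustration:

// Wire-side enum as prost would generate it (illustrative re-declaration).
#[derive(Clone, Copy, Debug, PartialEq)]
enum WireFinishReason {
    Length = 0,
    EosToken = 1,
    StopSequence = 2,
}

impl WireFinishReason {
    // Mirrors the `from_i32` constructor prost generates for proto enums.
    fn from_i32(value: i32) -> Option<Self> {
        match value {
            0 => Some(Self::Length),
            1 => Some(Self::EosToken),
            2 => Some(Self::StopSequence),
            _ => None,
        }
    }
}

// HTTP-layer enum, as in lib.rs above.
#[derive(Debug)]
enum FinishReason {
    Length,
    EndOfSequenceToken,
    StopSequence,
}

impl From<i32> for FinishReason {
    fn from(value: i32) -> Self {
        match WireFinishReason::from_i32(value).expect("unknown finish reason") {
            WireFinishReason::Length => FinishReason::Length,
            WireFinishReason::EosToken => FinishReason::EndOfSequenceToken,
            WireFinishReason::StopSequence => FinishReason::StopSequence,
        }
    }
}

fn main() {
    println!("{:?}", FinishReason::from(1)); // EndOfSequenceToken
}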
@@ -110,30 +110,58 @@ fn validate(
     max_input_length: usize,
     rng: &mut ThreadRng,
 ) -> Result<ValidGenerateRequest, ValidationError> {
-    if request.parameters.temperature <= 0.0 {
+    let GenerateParameters {
+        temperature,
+        repetition_penalty,
+        top_k,
+        top_p,
+        do_sample,
+        max_new_tokens,
+        stop: stop_sequences,
+        seed,
+        ..
+    } = request.parameters;
+
+    let temperature = temperature.unwrap_or(1.0);
+    if temperature <= 0.0 {
         return Err(ValidationError::Temperature);
     }
-    if request.parameters.repetition_penalty <= 0.0 {
+
+    let repetition_penalty = repetition_penalty.unwrap_or(1.0);
+    if repetition_penalty <= 0.0 {
         return Err(ValidationError::RepetitionPenalty);
     }
-    if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
+
+    let top_p = top_p.unwrap_or(1.0);
+    if top_p <= 0.0 || top_p > 1.0 {
         return Err(ValidationError::TopP);
     }
-    if request.parameters.top_k < 0 {
-        return Err(ValidationError::TopK);
-    }
-    if request.parameters.max_new_tokens > MAX_MAX_NEW_TOKENS {
+
+    // Different because the proto default value is 0 while it is not a valid value
+    // for the user
+    let top_k: u32 = match top_k {
+        None => Ok(0),
+        Some(top_k) => {
+            if top_k <= 0 {
+                return Err(ValidationError::TopK);
+            }
+            Ok(top_k as u32)
+        }
+    }?;
+
+    if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
         return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
     }
-    if request.parameters.stop.len() > MAX_STOP_SEQUENCES {
+
+    if stop_sequences.len() > MAX_STOP_SEQUENCES {
         return Err(ValidationError::StopSequence(
             MAX_STOP_SEQUENCES,
-            request.parameters.stop.len(),
+            stop_sequences.len(),
         ));
     }

     // If seed is None, assign a random one
-    let seed = match request.parameters.seed {
+    let seed = match seed {
         None => rng.gen(),
         Some(seed) => seed,
     };
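Validation now owns the defaults (1.0 for temperature, repetition_penalty and top_p; 0 for top_k, the proto default) rather than serde. A compact sketch of the top_k rule under those semantics, written as a standalone function equivalent to the match above:

#[derive(Debug, PartialEq)]
enum ValidationError {
    TopK,
}

// None means "not set by the user": map it to the proto default 0.
// A user-supplied value must be strictly positive.
fn validate_top_k(top_k: Option<i32>) -> Result<u32, ValidationError> {
    match top_k {
        None => Ok(0),
        Some(k) if k > 0 => Ok(k as u32),
        Some(_) => Err(ValidationError::TopK),
    }
}

fn main() {
    assert_eq!(validate_top_k(None), Ok(0));
    assert_eq!(validate_top_k(Some(10)), Ok(10));
    assert_eq!(validate_top_k(Some(-1)), Err(ValidationError::TopK));
    println!("top_k validation behaves as expected");
}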
@@ -147,21 +175,10 @@ fn validate(
         Err(ValidationError::InputLength(input_length, max_input_length))
     } else {
         // Return ValidGenerateRequest
-        let GenerateParameters {
-            temperature,
-            repetition_penalty,
-            top_k,
-            top_p,
-            do_sample,
-            max_new_tokens,
-            stop: stop_sequences,
-            ..
-        } = request.parameters;
-
         let parameters = NextTokenChooserParameters {
             temperature,
             repetition_penalty,
-            top_k: top_k as u32,
+            top_k,
             top_p,
             do_sample,
             seed,
@@ -206,7 +223,7 @@ pub enum ValidationError {
     TopP,
     #[error("top_k must be strictly positive")]
     TopK,
-    #[error("max_new_tokens must be <= {0}")]
+    #[error("max_new_tokens must be strictly positive and <= {0}")]
     MaxNewTokens(u32),
     #[error("inputs must have less than {1} tokens. Given: {0}")]
     InputLength(usize, usize),
@@ -1,6 +1,6 @@
-# BLOOM Inference Python gRPC Server
+# Text Generation Inference Python gRPC Server

-A Python gRPC server for BLOOM Inference
+A Python gRPC server for Text Generation Inference

 ## Install

@@ -1,7 +1,7 @@
 [tool.poetry]
 name = "text-generation"
-version = "0.1.0"
-description = "BLOOM Inference Python gRPC Server"
+version = "0.2.0"
+description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]

 [tool.poetry.scripts]
@@ -140,8 +140,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)

     assert len(generations) == 1
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -187,8 +186,7 @@ def test_causal_lm_generate_token_completion_multi(

     assert len(generations) == 1
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -283,8 +281,7 @@ def test_batch_concatenate(

     assert len(generations) == 2
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -306,8 +303,7 @@ def test_batch_concatenate(

     assert len(generations) == 1
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -9,6 +9,7 @@ from text_generation.utils import (
     StopSequenceCriteria,
     StoppingCriteria,
     LocalEntryNotFoundError,
+    FinishReason,
 )


@@ -24,13 +25,13 @@ def test_stop_sequence_criteria():
 def test_stopping_criteria():
     criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
     assert criteria(65827, "/test") == (False, None)
-    assert criteria(30, ";") == (True, "stop_sequence")
+    assert criteria(30, ";") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)


 def test_stopping_criteria_eos():
     criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
     assert criteria(1, "") == (False, None)
-    assert criteria(0, "") == (True, "eos_token")
+    assert criteria(0, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN)


 def test_stopping_criteria_max():
@@ -39,7 +40,7 @@ def test_stopping_criteria_max():
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (False, None)
-    assert criteria(1, "") == (True, "length")
+    assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)


 def test_weight_hub_files():
@@ -13,7 +13,7 @@ app = typer.Typer()

 @app.command()
 def serve(
-    model_name: str,
+    model_id: str,
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: bool = False,
@@ -46,16 +46,16 @@ def serve(
         os.getenv("MASTER_PORT", None) is not None
     ), "MASTER_PORT must be set when sharded is True"

-    server.serve(model_name, revision, sharded, quantize, uds_path)
+    server.serve(model_id, revision, sharded, quantize, uds_path)


 @app.command()
 def download_weights(
-    model_name: str,
+    model_id: str,
     revision: Optional[str] = None,
     extension: str = ".safetensors",
 ):
-    utils.download_weights(model_name, revision, extension)
+    utils.download_weights(model_id, revision, extension)


 if __name__ == "__main__":
@@ -30,31 +30,31 @@ torch.backends.cudnn.allow_tf32 = True


 def get_model(
-    model_name: str, revision: Optional[str], sharded: bool, quantize: bool
+    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
-    config = AutoConfig.from_pretrained(model_name, revision=revision)
+    config = AutoConfig.from_pretrained(model_id, revision=revision)

     if config.model_type == "bloom":
         if sharded:
-            return BLOOMSharded(model_name, revision, quantize=quantize)
+            return BLOOMSharded(model_id, revision, quantize=quantize)
         else:
-            return BLOOM(model_name, revision, quantize=quantize)
+            return BLOOM(model_id, revision, quantize=quantize)
     elif config.model_type == "gpt_neox":
         if sharded:
-            return GPTNeoxSharded(model_name, revision, quantize=quantize)
+            return GPTNeoxSharded(model_id, revision, quantize=quantize)
         else:
-            return GPTNeox(model_name, revision, quantize=quantize)
-    elif model_name.startswith("facebook/galactica"):
+            return GPTNeox(model_id, revision, quantize=quantize)
+    elif model_id.startswith("facebook/galactica"):
         if sharded:
-            return GalacticaSharded(model_name, revision, quantize=quantize)
+            return GalacticaSharded(model_id, revision, quantize=quantize)
         else:
-            return Galactica(model_name, revision, quantize=quantize)
-    elif "santacoder" in model_name:
-        return SantaCoder(model_name, revision, quantize)
+            return Galactica(model_id, revision, quantize=quantize)
+    elif "santacoder" in model_id:
+        return SantaCoder(model_id, revision, quantize)
     else:
         if sharded:
             raise ValueError("sharded is not supported for AutoModel")
         try:
-            return CausalLM(model_name, revision, quantize=quantize)
+            return CausalLM(model_id, revision, quantize=quantize)
         except Exception:
-            return Seq2SeqLM(model_name, revision, quantize=quantize)
+            return Seq2SeqLM(model_id, revision, quantize=quantize)
@@ -57,10 +57,10 @@ class BLOOM(CausalLM):

 class BLOOMSharded(BLOOM):
     def __init__(
-        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        if not model_name.startswith("bigscience/bloom"):
-            raise ValueError(f"Model {model_name} is not supported")
+        if not model_id.startswith("bigscience/bloom"):
+            raise ValueError(f"Model {model_id} is not supported")

         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -72,22 +72,20 @@ class BLOOMSharded(BLOOM):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )

         config = AutoConfig.from_pretrained(
-            model_name, revision=revision, slow_but_exact=False, tp_parallel=True
+            model_id, revision=revision, slow_but_exact=False, tp_parallel=True
         )
         config.pad_token_id = 3

         # Only download weights for small models
-        if self.master and model_name == "bigscience/bloom-560m":
-            download_weights(model_name, revision=revision, extension=".safetensors")
+        if self.master and model_id == "bigscience/bloom-560m":
+            download_weights(model_id, revision=revision, extension=".safetensors")

         torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(
-            model_name, revision=revision, extension=".safetensors"
-        )
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         if not filenames:
             raise ValueError("No safetensors weights found")

@@ -232,7 +232,7 @@ class CausalLMBatch(Batch):


 class CausalLM(Model):
-    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -244,10 +244,10 @@ class CausalLM(Model):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
+            model_id,
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
@@ -149,10 +149,10 @@ class Galactica(CausalLM):

 class GalacticaSharded(Galactica):
     def __init__(
-        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        if not model_name.startswith("facebook/galactica"):
-            raise ValueError(f"Model {model_name} is not supported")
+        if not model_id.startswith("facebook/galactica"):
+            raise ValueError(f"Model {model_id} is not supported")

         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -164,22 +164,20 @@ class GalacticaSharded(Galactica):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )

         config = AutoConfig.from_pretrained(
-            model_name, revision=revision, tp_parallel=True
+            model_id, revision=revision, tp_parallel=True
         )
         tokenizer.pad_token_id = config.pad_token_id

         # Only download weights for small models
-        if self.master and model_name == "facebook/galactica-125m":
-            download_weights(model_name, revision=revision, extension=".safetensors")
+        if self.master and model_id == "facebook/galactica-125m":
+            download_weights(model_id, revision=revision, extension=".safetensors")

         torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(
-            model_name, revision=revision, extension=".safetensors"
-        )
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         if not filenames:
             raise ValueError("No safetensors weights found")

@@ -49,7 +49,7 @@ class GPTNeox(CausalLM):

 class GPTNeoxSharded(GPTNeox):
     def __init__(
-        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
@@ -61,22 +61,20 @@ class GPTNeoxSharded(GPTNeox):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )
         tokenizer.pad_token = tokenizer.eos_token

         config = AutoConfig.from_pretrained(
-            model_name, revision=revision, tp_parallel=True
+            model_id, revision=revision, tp_parallel=True
         )

         # Only master download weights
         if self.master:
-            download_weights(model_name, revision=revision, extension=".safetensors")
+            download_weights(model_id, revision=revision, extension=".safetensors")

         torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(
-            model_name, revision=revision, extension=".safetensors"
-        )
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         if not filenames:
             raise ValueError("No safetensors weights found")

@@ -14,7 +14,7 @@ EOD = "<|endoftext|>"


 class SantaCoder(CausalLM):
-    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -26,7 +26,7 @@ class SantaCoder(CausalLM):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )
         tokenizer.add_special_tokens(
             {
@@ -43,7 +43,7 @@ class SantaCoder(CausalLM):

         self.model = (
             AutoModelForCausalLM.from_pretrained(
-                model_name,
+                model_id,
                 revision=revision,
                 torch_dtype=dtype,
                 load_in_8bit=quantize,
@@ -289,7 +289,7 @@ class Seq2SeqLMBatch(Batch):


 class Seq2SeqLM(Model):
-    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -301,14 +301,14 @@ class Seq2SeqLM(Model):
             dtype = torch.float32

         self.model = AutoModelForSeq2SeqLM.from_pretrained(
-            model_name,
+            model_id,
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
             load_in_8bit=quantize,
         ).eval()
         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )
         tokenizer.bos_token_id = self.model.config.decoder_start_token_id

@@ -7,6 +7,7 @@ from typing import List, Optional
 from transformers import PreTrainedTokenizerBase

 from text_generation.pb import generate_pb2
+from text_generation.pb.generate_pb2 import FinishReason


 class Batch(ABC):
@@ -38,7 +39,7 @@ class Batch(ABC):
 class GeneratedText:
     text: str
     generated_tokens: int
-    finish_reason: str
+    finish_reason: FinishReason
     seed: Optional[int]

     def to_pb(self) -> generate_pb2.GeneratedText:
@@ -66,14 +66,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):


 def serve(
-    model_name: str,
+    model_id: str,
     revision: Optional[str],
     sharded: bool,
     quantize: bool,
     uds_path: Path,
 ):
     async def serve_inner(
-        model_name: str,
+        model_id: str,
         revision: Optional[str],
         sharded: bool = False,
         quantize: bool = False,
@@ -89,7 +89,7 @@ def serve(
         local_url = unix_socket_template.format(uds_path, 0)
         server_urls = [local_url]

-        model = get_model(model_name, revision, sharded, quantize)
+        model = get_model(model_id, revision, sharded, quantize)

         server = aio.server(interceptors=[ExceptionInterceptor()])
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
@@ -109,4 +109,4 @@ def serve(
         logger.info("Signal received. Shutting down")
         await server.stop(0)

-    asyncio.run(serve_inner(model_name, revision, sharded, quantize))
+    asyncio.run(serve_inner(model_id, revision, sharded, quantize))
@@ -24,9 +24,11 @@ from transformers.generation.logits_process import (
 )

 from text_generation.pb import generate_pb2
+from text_generation.pb.generate_pb2 import FinishReason
+
 WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)


 class Sampling:
     def __init__(self, seed: int, device: str = "cpu"):
         self.generator = torch.Generator(device)
@@ -129,15 +131,15 @@ class StoppingCriteria:
     def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
         self.current_tokens += 1
         if self.current_tokens >= self.max_new_tokens:
-            return True, "length"
+            return True, FinishReason.FINISH_REASON_LENGTH

         if last_token == self.eos_token_id:
-            return True, "eos_token"
+            return True, FinishReason.FINISH_REASON_EOS_TOKEN

         self.current_output += last_output
         for stop_sequence_criteria in self.stop_sequence_criterias:
             if stop_sequence_criteria(self.current_output):
-                return True, "stop_sequence"
+                return True, FinishReason.FINISH_REASON_STOP_SEQUENCE

         return False, None

@@ -180,20 +182,20 @@ def initialize_torch_distributed():
     return torch.distributed.distributed_c10d._get_default_group(), rank, world_size


-def weight_hub_files(model_name, revision=None, extension=".safetensors"):
+def weight_hub_files(model_id, revision=None, extension=".safetensors"):
     """Get the safetensors filenames on the hub"""
     api = HfApi()
-    info = api.model_info(model_name, revision=revision)
+    info = api.model_info(model_id, revision=revision)
     filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
     return filenames


-def try_to_load_from_cache(model_name, revision, filename):
+def try_to_load_from_cache(model_id, revision, filename):
     """Try to load a file from the Hugging Face cache"""
     if revision is None:
         revision = "main"

-    object_id = model_name.replace("/", "--")
+    object_id = model_id.replace("/", "--")
     repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"

     if not repo_cache.is_dir():
@@ -228,38 +230,38 @@ def try_to_load_from_cache(model_name, revision, filename):
     return str(cached_file) if cached_file.is_file() else None


-def weight_files(model_name, revision=None, extension=".safetensors"):
+def weight_files(model_id, revision=None, extension=".safetensors"):
     """Get the local safetensors filenames"""
     if WEIGHTS_CACHE_OVERRIDE is not None:
         return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))

-    filenames = weight_hub_files(model_name, revision, extension)
+    filenames = weight_hub_files(model_id, revision, extension)
     files = []
     for filename in filenames:
         cache_file = try_to_load_from_cache(
-            model_name, revision=revision, filename=filename
+            model_id, revision=revision, filename=filename
         )
         if cache_file is None:
             raise LocalEntryNotFoundError(
-                f"File {filename} of model {model_name} not found in "
+                f"File {filename} of model {model_id} not found in "
                 f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
-                f"Please run `text-generation-server download-weights {model_name}` first."
+                f"Please run `text-generation-server download-weights {model_id}` first."
             )
         files.append(cache_file)

     return files


-def download_weights(model_name, revision=None, extension=".safetensors"):
+def download_weights(model_id, revision=None, extension=".safetensors"):
     """Download the safetensors files from the hub"""
     if WEIGHTS_CACHE_OVERRIDE is not None:
         return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))

-    filenames = weight_hub_files(model_name, revision, extension)
+    filenames = weight_hub_files(model_id, revision, extension)

     download_function = partial(
         hf_hub_download,
-        repo_id=model_name,
+        repo_id=model_id,
         local_files_only=False,
     )
