This commit is contained in:
Olivier Dehaene 2022-10-18 15:19:03 +02:00 committed by OlivierDehaene
parent 92c1ecd008
commit f16f2f5ae1
36 changed files with 1556 additions and 677 deletions

View File

@ -1,2 +1,2 @@
aml aml
router/target target

1
.gitignore vendored
View File

@ -1 +1,2 @@
.idea .idea
target

View File

@ -55,9 +55,9 @@ dependencies = [
[[package]] [[package]]
name = "async-trait" name = "async-trait"
version = "0.1.57" version = "0.1.58"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76464446b8bc32758d7e88ee1a804d9914cd9b1cb264c029899680b0be29826f" checksum = "1e805d94e6b5001b651426cf4cd446b1ab5f319d27bab5c644f61de0a804360c"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -166,9 +166,9 @@ dependencies = [
[[package]] [[package]]
name = "bumpalo" name = "bumpalo"
version = "3.11.0" version = "3.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1ad822118d20d2c234f427000d5acc36eabe1e29a348c89b63dd60b13f28e5d" checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba"
[[package]] [[package]]
name = "byteorder" name = "byteorder"
@ -255,9 +255,9 @@ dependencies = [
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.0.15" version = "4.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6bf8832993da70a4c6d13c581f4463c2bdda27b9bf1c5498dc4365543abe6d6f" checksum = "06badb543e734a2d6568e19a40af66ed5364360b9226184926f89d229b4b4267"
dependencies = [ dependencies = [
"atty", "atty",
"bitflags", "bitflags",
@ -391,6 +391,16 @@ dependencies = [
"typenum", "typenum",
] ]
[[package]]
name = "ctrlc"
version = "3.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d91974fbbe88ec1df0c24a4f00f99583667a7e2e6272b2b92d294d81e462173"
dependencies = [
"nix",
"winapi",
]
[[package]] [[package]]
name = "darling" name = "darling"
version = "0.10.2" version = "0.10.2"
@ -529,7 +539,7 @@ dependencies = [
"cfg-if", "cfg-if",
"libc", "libc",
"redox_syscall", "redox_syscall",
"windows-sys", "windows-sys 0.36.1",
] ]
[[package]] [[package]]
@ -936,9 +946,9 @@ dependencies = [
[[package]] [[package]]
name = "itoa" name = "itoa"
version = "1.0.3" version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c8af84674fe1f223a982c933a0ee1086ac4d4052aa0fb8060c12c6ad838e754" checksum = "4217ad341ebadf8d8e724e264f13e593e0648f5b3e94b3896a5df283be015ecc"
[[package]] [[package]]
name = "js-sys" name = "js-sys"
@ -957,9 +967,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.134" version = "0.2.135"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "329c933548736bc49fd575ee68c89e8be4d260064184389a5b77517cddd99ffb" checksum = "68783febc7782c6c5cb401fbda4de5a9898be1762314da0bb2c10ced61f18b0c"
[[package]] [[package]]
name = "lock_api" name = "lock_api"
@ -1047,7 +1057,7 @@ dependencies = [
"libc", "libc",
"log", "log",
"wasi 0.11.0+wasi-snapshot-preview1", "wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys", "windows-sys 0.36.1",
] ]
[[package]] [[package]]
@ -1074,6 +1084,18 @@ dependencies = [
"tempfile", "tempfile",
] ]
[[package]]
name = "nix"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e322c04a9e3440c327fca7b6c8a63e6890a32fa2ad689db972425f07e0d22abb"
dependencies = [
"autocfg",
"bitflags",
"cfg-if",
"libc",
]
[[package]] [[package]]
name = "nom" name = "nom"
version = "7.1.1" version = "7.1.1"
@ -1084,6 +1106,16 @@ dependencies = [
"minimal-lexical", "minimal-lexical",
] ]
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
"overload",
"winapi",
]
[[package]] [[package]]
name = "num_cpus" name = "num_cpus"
version = "1.13.1" version = "1.13.1"
@ -1185,6 +1217,12 @@ version = "6.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff" checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff"
[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]] [[package]]
name = "parking_lot" name = "parking_lot"
version = "0.12.1" version = "0.12.1"
@ -1197,15 +1235,15 @@ dependencies = [
[[package]] [[package]]
name = "parking_lot_core" name = "parking_lot_core"
version = "0.9.3" version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09a279cbf25cb0757810394fbc1e359949b59e348145c643a939a525692e6929" checksum = "4dc9e0dc2adc1c69d09143aff38d3d30c5c3f0df0dad82e6d25547af174ebec0"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"libc", "libc",
"redox_syscall", "redox_syscall",
"smallvec", "smallvec",
"windows-sys", "windows-sys 0.42.0",
] ]
[[package]] [[package]]
@ -1300,9 +1338,9 @@ dependencies = [
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.46" version = "1.0.47"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94e2ef8dbfc347b10c094890f778ee2e36ca9bb4262e86dc99cd217e35f3470b" checksum = "5ea3d908b0e36316caf9e9e2c4625cdde190a7e6f440d794667ed17a1855e725"
dependencies = [ dependencies = [
"unicode-ident", "unicode-ident",
] ]
@ -1530,7 +1568,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88d6731146462ea25d9244b2ed5fd1d716d25c52e4d54aa4fb0f3c4e9854dbe2" checksum = "88d6731146462ea25d9244b2ed5fd1d716d25c52e4d54aa4fb0f3c4e9854dbe2"
dependencies = [ dependencies = [
"lazy_static", "lazy_static",
"windows-sys", "windows-sys 0.36.1",
] ]
[[package]] [[package]]
@ -1584,9 +1622,9 @@ dependencies = [
[[package]] [[package]]
name = "serde_json" name = "serde_json"
version = "1.0.85" version = "1.0.86"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e55a28e3aaef9d5ce0506d0a14dbba8054ddc7e499ef522dd8b26859ec9d4a44" checksum = "41feea4228a6f1cd09ec7a3593a682276702cd67b5273544757dae23c096f074"
dependencies = [ dependencies = [
"itoa", "itoa",
"ryu", "ryu",
@ -1625,6 +1663,15 @@ dependencies = [
"lazy_static", "lazy_static",
] ]
[[package]]
name = "signal-hook-registry"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "slab" name = "slab"
version = "0.4.7" version = "0.4.7"
@ -1681,10 +1728,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]] [[package]]
name = "syn" name = "subprocess"
version = "1.0.101" version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e90cde112c4b9690b8cbe810cba9ddd8bc1d7472e2cae317b69e9438c1cba7d2" checksum = "0c2e86926081dda636c546d8c5e641661049d7562a68f5488be4a1f7f66f6086"
dependencies = [
"libc",
"winapi",
]
[[package]]
name = "syn"
version = "1.0.102"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fcd952facd492f9be3ef0d0b7032a6e442ee9b361d4acc2b1d0c4aaa5f613a1"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -1741,13 +1798,24 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "text-generation-launcher"
version = "0.1.0"
dependencies = [
"clap 4.0.17",
"ctrlc",
"subprocess",
"tracing",
"tracing-subscriber",
]
[[package]] [[package]]
name = "text-generation-router" name = "text-generation-router"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"axum", "axum",
"bloom-inference-client", "bloom-inference-client",
"clap 4.0.15", "clap 4.0.17",
"futures", "futures",
"parking_lot", "parking_lot",
"serde", "serde",
@ -1872,6 +1940,7 @@ dependencies = [
"num_cpus", "num_cpus",
"parking_lot", "parking_lot",
"pin-project-lite", "pin-project-lite",
"signal-hook-registry",
"socket2", "socket2",
"tokio-macros", "tokio-macros",
"winapi", "winapi",
@ -1910,9 +1979,9 @@ dependencies = [
[[package]] [[package]]
name = "tokio-stream" name = "tokio-stream"
version = "0.1.10" version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6edf2d6bc038a43d31353570e27270603f4648d18f5ed10c0e179abe43255af" checksum = "d660770404473ccd7bc9f8b28494a811bc18542b915c0855c51e8f419d5223ce"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"pin-project-lite", "pin-project-lite",
@ -2031,9 +2100,9 @@ dependencies = [
[[package]] [[package]]
name = "tower-layer" name = "tower-layer"
version = "0.3.1" version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62" checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0"
[[package]] [[package]]
name = "tower-service" name = "tower-service"
@ -2043,9 +2112,9 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52"
[[package]] [[package]]
name = "tracing" name = "tracing"
version = "0.1.36" version = "0.1.37"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fce9567bd60a67d08a16488756721ba392f24f29006402881e43b19aac64307" checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"log", "log",
@ -2056,9 +2125,9 @@ dependencies = [
[[package]] [[package]]
name = "tracing-attributes" name = "tracing-attributes"
version = "0.1.22" version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11c75893af559bc8e10716548bdef5cb2b983f8e637db9d0e15126b61b484ee2" checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -2067,9 +2136,9 @@ dependencies = [
[[package]] [[package]]
name = "tracing-core" name = "tracing-core"
version = "0.1.29" version = "0.1.30"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aeea4303076558a00714b823f9ad67d58a3bbda1df83d8827d21193156e22f7" checksum = "24eb03ba0eab1fd845050058ce5e616558e8f8d8fca633e6b163fe25c797213a"
dependencies = [ dependencies = [
"once_cell", "once_cell",
"valuable", "valuable",
@ -2108,11 +2177,11 @@ dependencies = [
[[package]] [[package]]
name = "tracing-subscriber" name = "tracing-subscriber"
version = "0.3.15" version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60db860322da191b40952ad9affe65ea23e7dd6a5c442c2c42865810c6ab8e6b" checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70"
dependencies = [ dependencies = [
"ansi_term", "nu-ansi-term",
"sharded-slab", "sharded-slab",
"smallvec", "smallvec",
"thread_local", "thread_local",
@ -2140,9 +2209,9 @@ checksum = "099b7128301d285f79ddd55b9a83d5e6b9e97c92e0ea0daebee7263e932de992"
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.4" version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcc811dc4066ac62f84f11307873c4850cb653bfa9b1719cee2bd2204a4bc5dd" checksum = "6ceab39d59e4c9499d4e5a8ee0e2735b891bb7308ac83dfb4e80cad195c9f6f3"
[[package]] [[package]]
name = "unicode-normalization" name = "unicode-normalization"
@ -2361,43 +2430,100 @@ version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2" checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2"
dependencies = [ dependencies = [
"windows_aarch64_msvc", "windows_aarch64_msvc 0.36.1",
"windows_i686_gnu", "windows_i686_gnu 0.36.1",
"windows_i686_msvc", "windows_i686_msvc 0.36.1",
"windows_x86_64_gnu", "windows_x86_64_gnu 0.36.1",
"windows_x86_64_msvc", "windows_x86_64_msvc 0.36.1",
] ]
[[package]]
name = "windows-sys"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc 0.42.0",
"windows_i686_gnu 0.42.0",
"windows_i686_msvc 0.42.0",
"windows_x86_64_gnu 0.42.0",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc 0.42.0",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41d2aa71f6f0cbe00ae5167d90ef3cfe66527d6f613ca78ac8024c3ccab9a19e"
[[package]] [[package]]
name = "windows_aarch64_msvc" name = "windows_aarch64_msvc"
version = "0.36.1" version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47" checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47"
[[package]]
name = "windows_aarch64_msvc"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd0f252f5a35cac83d6311b2e795981f5ee6e67eb1f9a7f64eb4500fbc4dcdb4"
[[package]] [[package]]
name = "windows_i686_gnu" name = "windows_i686_gnu"
version = "0.36.1" version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6" checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6"
[[package]]
name = "windows_i686_gnu"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbeae19f6716841636c28d695375df17562ca208b2b7d0dc47635a50ae6c5de7"
[[package]] [[package]]
name = "windows_i686_msvc" name = "windows_i686_msvc"
version = "0.36.1" version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024" checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024"
[[package]]
name = "windows_i686_msvc"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "84c12f65daa39dd2babe6e442988fc329d6243fdce47d7d2d155b8d874862246"
[[package]] [[package]]
name = "windows_x86_64_gnu" name = "windows_x86_64_gnu"
version = "0.36.1" version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1" checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1"
[[package]]
name = "windows_x86_64_gnu"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf7b1b21b5362cbc318f686150e5bcea75ecedc74dd157d874d754a2ca44b0ed"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09d525d2ba30eeb3297665bd434a54297e4170c7f1a44cad4ef58095b4cd2028"
[[package]] [[package]]
name = "windows_x86_64_msvc" name = "windows_x86_64_msvc"
version = "0.36.1" version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680" checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680"
[[package]]
name = "windows_x86_64_msvc"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5"
[[package]] [[package]]
name = "winreg" name = "winreg"
version = "0.10.1" version = "0.10.1"

11
Cargo.toml Normal file
View File

@ -0,0 +1,11 @@
[workspace]
members = [
"router",
"router/client",
"launcher"
]
[profile.release]
debug = 1
incremental = true
lto = "off"

View File

@ -1,4 +1,4 @@
FROM rust:1.64 as builder FROM rust:1.64 as router-builder
WORKDIR /usr/src WORKDIR /usr/src
@ -9,7 +9,17 @@ WORKDIR /usr/src/router
RUN cargo install --path . RUN cargo install --path .
FROM nvidia/cuda:11.6.1-devel-ubuntu18.04 FROM rust:1.64 as launcher-builder
WORKDIR /usr/src
COPY launcher launcher
WORKDIR /usr/src/launcher
RUN cargo install --path .
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
ENV LANG=C.UTF-8 \ ENV LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \ LC_ALL=C.UTF-8 \
@ -34,17 +44,15 @@ RUN cd ~ && \
bash ./Miniconda3-latest-Linux-x86_64.sh -bf -p /opt/miniconda && \ bash ./Miniconda3-latest-Linux-x86_64.sh -bf -p /opt/miniconda && \
conda create -n text-generation python=3.9 -y conda create -n text-generation python=3.9 -y
WORKDIR /usr/src
COPY server/Makefile server/Makefile
# Install specific version of torch # Install specific version of torch
RUN /opt/miniconda/envs/text-generation/bin/pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir RUN cd server && make install-torch
# Install specific version of transformers # Install specific version of transformers
RUN wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \ RUN cd server && make install-transformers
unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
cd transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 && \
/opt/miniconda/envs/text-generation/bin/python setup.py install
WORKDIR /usr/src
# Install server # Install server
COPY server server COPY server server
@ -52,9 +60,7 @@ RUN cd server && \
/opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir /opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir
# Install router # Install router
COPY --from=builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher
COPY run.sh . CMD text-generation-launcher --model-name $MODEL_NAME --num-shard $NUM_GPUS --shard-directory $MODEL_BASE_PATH
RUN chmod +x run.sh
CMD ["./run.sh"]

19
Makefile Normal file
View File

@ -0,0 +1,19 @@
install-server:
cd server && make pip-install
install-router:
cd router && cargo install --path .
install-launcher:
cd launcher && cargo install --path .
install:
make install-server
make install-router
make install-launcher
run-bloom-560m:
text-generation-launcher --model-name bigscience/bloom-560m --shard-directory /tmp/models --num-shard 2
run-bloom:
text-generation-launcher --model-name bigscience/bloom --shard-directory /tmp/models --num-shard 8

View File

@ -1,50 +1,51 @@
# Text Generation Inference # LLM Text Generation Inference
A Rust and gRPC server for text generation inference. <div align="center">
## Load Tests ![architecture](assets/architecture.jpg)
</div>
A Rust and gRPC server for large language models text generation inference.
## Load Tests for BLOOM
See `k6/load_test.js` See `k6/load_test.js`
We send the default examples with a 1 second delay between each request. We send the default examples with a 1 second delay between requests.
Stages: Stages:
- Ramp up to 50 concurrent requests per second in 1min - Ramp up to 50 vus in 1min
- Ramp up from 50 to 100 concurrent requests per second in 2min - Ramp up from 50 to 100 vus in 2min
- Ramp down to 0 concurrent requests per second in 1min - Ramp down to 0 vus in 1min
| | avg | min | med | max | p(90) | p(95) | RPS | | | avg | min | med | max | p(90) | p(95) | RPS |
|------------------------|-----------|-----------|-----------|------------|-----------|-----------|----------| |--------------------------------------------------------------|-----------|--------------|-----------|------------|-----------|-----------|----------|
| Original code | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 | | [Original code](https://github.com/huggingface/transformers_bloom_parallel) | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 |
| ISO with original code | 8.88s | 959.53ms | 8.89s | 17.08s | 13.34s | 14.12s | 5.94 | | ISO with original code | 8.88s | **959.53ms** | 8.89s | 17.08s | 13.34s | 14.12s | 5.94 |
| New batching logic | **5.44s** | **1.27s** | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** | | New batching logic | **5.44s** | 1.27s | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |
## Install ## Install
```shell ```shell
cd server make install
pip install .
```
```
cd router
cargo build --release
``` ```
## Run ## Run
```shell ```shell
python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-directory /dev/shm/models make run-bloom-560m
``` ```
## Test
```shell ```shell
./router/target/release/router curl 127.0.0.1:3000/generate \
-X POST \
-d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
-H 'Content-Type: application/json'
``` ```
## TODO: ## TODO:
- [ ] Add docstrings + comments everywhere as the codebase is fairly complicated - [ ] Add tests for the `server/model` logic
- [ ] Add tests
- [ ] Add shutdown logic in router and server
- [ ] Improve multi-processing logic in server
- [ ] Improve past key layer indexing?

View File

@ -8,7 +8,7 @@ environment_variables:
MODEL_NAME: bigscience/bloom MODEL_NAME: bigscience/bloom
NUM_GPUS: 8 NUM_GPUS: 8
environment: environment:
image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.3 image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.2
inference_config: inference_config:
liveness_route: liveness_route:
port: 3000 port: 3000

BIN
assets/architecture.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 132 KiB

13
launcher/Cargo.toml Normal file
View File

@ -0,0 +1,13 @@
[package]
name = "text-generation-launcher"
version = "0.1.0"
edition = "2021"
authors = ["Olivier Dehaene"]
description = "Text Generation Launcher"
[dependencies]
clap = { version = "4.0.15", features = ["derive", "env"] }
ctrlc = "3.2.3"
subprocess = "0.2.9"
tracing = "0.1.37"
tracing-subscriber = "0.3.16"

358
launcher/src/main.rs Normal file
View File

@ -0,0 +1,358 @@
use clap::Parser;
use std::io::{BufRead, BufReader, Read};
use std::path::Path;
use std::process::ExitCode;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::TryRecvError;
use std::sync::Arc;
use std::sync::{mpsc, Mutex};
use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};
use subprocess::{Popen, PopenConfig, PopenError, Redirection};
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(default_value = "bigscience/bloom-560m", long, env)]
model_name: String,
#[clap(long, env)]
num_shard: Option<usize>,
#[clap(long, env)]
shard_directory: Option<String>,
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "1000", long, env)]
max_input_length: usize,
#[clap(default_value = "32", long, env)]
max_batch_size: usize,
#[clap(default_value = "5", long, env)]
max_waiting_time: u64,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-server", long, env)]
shard_uds_path: String,
#[clap(default_value = "localhost", long, env)]
master_addr: String,
#[clap(default_value = "29500", long, env)]
master_port: usize,
}
fn main() -> ExitCode {
tracing_subscriber::fmt::init();
// Pattern match configuration
let Args {
model_name,
num_shard,
shard_directory,
max_concurrent_requests,
max_input_length,
max_batch_size,
max_waiting_time,
port,
shard_uds_path,
master_addr,
master_port,
} = Args::parse();
// By default we only have one master shard
let num_shard = num_shard.unwrap_or(1);
// Signal handler
let running = Arc::new(AtomicBool::new(true));
let r = running.clone();
ctrlc::set_handler(move || {
r.store(false, Ordering::SeqCst);
})
.expect("Error setting Ctrl-C handler");
// Shared shutdown bool
let shutdown = Arc::new(Mutex::new(false));
// Shared shutdown channel
// When shutting down, the main thread will wait for all senders to be dropped
let (shutdown_sender, shutdown_receiver) = mpsc::channel();
// Shared channel to track shard status
let (status_sender, status_receiver) = mpsc::channel();
// Start shard processes
for rank in 0..num_shard {
let model_name = model_name.clone();
let uds_path = shard_uds_path.clone();
let shard_directory = shard_directory.clone();
let master_addr = master_addr.clone();
let status_sender = status_sender.clone();
let shutdown = shutdown.clone();
let shutdown_sender = shutdown_sender.clone();
thread::spawn(move || {
shard_manager(
model_name,
uds_path,
shard_directory,
rank,
num_shard,
master_addr,
master_port,
status_sender,
shutdown,
shutdown_sender,
)
});
}
drop(shutdown_sender);
// Wait for shard to start
let mut shard_ready = 0;
while running.load(Ordering::SeqCst) {
match status_receiver.try_recv() {
Ok(ShardStatus::Ready) => {
shard_ready += 1;
if shard_ready == num_shard {
break;
}
}
Err(TryRecvError::Empty) => {
sleep(Duration::from_millis(100));
}
Ok(ShardStatus::Failed((rank, err))) => {
tracing::error!("Shard {} failed to start:\n{}", rank, err);
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::FAILURE;
}
Err(TryRecvError::Disconnected) => {
tracing::error!("Shard status channel disconnected");
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::FAILURE;
}
}
}
// We might have received a termination signal
if !running.load(Ordering::SeqCst) {
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::SUCCESS;
}
// All shard started
// Start webserver
tracing::info!("Starting Webserver");
let mut webserver = match Popen::create(
&[
"text-generation-router",
"--max-concurrent-requests",
&max_concurrent_requests.to_string(),
"--max-input-length",
&max_input_length.to_string(),
"--max-batch-size",
&max_batch_size.to_string(),
"--max-waiting-time",
&max_waiting_time.to_string(),
"--port",
&port.to_string(),
"--master-shard-uds-path",
&format!("{}-0", shard_uds_path),
"--tokenizer-name",
&model_name,
],
PopenConfig {
stdout: Redirection::Pipe,
stderr: Redirection::Pipe,
// Needed for the shutdown procedure
setpgid: true,
..Default::default()
},
) {
Ok(p) => p,
Err(err) => {
tracing::error!("Failed to start webserver: {}", err);
if let PopenError::IoError(err) = err {
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("text-generation-router not found in PATH");
tracing::error!("Please install it with `make install-router`")
}
}
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::FAILURE;
}
};
// Redirect STDOUT and STDERR to the console
let webserver_stdout = webserver.stdout.take().unwrap();
let webserver_stderr = webserver.stderr.take().unwrap();
thread::spawn(move || {
let stdout = BufReader::new(webserver_stdout);
let stderr = BufReader::new(webserver_stderr);
for line in stdout.lines() {
println!("{}", line.unwrap());
}
for line in stderr.lines() {
println!("{}", line.unwrap());
}
});
// Default exit code
let mut exit_code = ExitCode::SUCCESS;
while running.load(Ordering::SeqCst) {
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
tracing::error!("Shard {} failed:\n{}", rank, err);
exit_code = ExitCode::FAILURE;
break;
};
match webserver.poll() {
Some(_) => {
tracing::error!("Webserver Crashed");
shutdown_shards(shutdown, &shutdown_receiver);
return ExitCode::FAILURE;
}
None => {
sleep(Duration::from_millis(100));
}
};
}
// Graceful termination
webserver.terminate().unwrap();
tracing::info!("Waiting for webserver to gracefully shutdown");
webserver.wait_timeout(Duration::from_secs(90)).unwrap();
tracing::info!("Webserver terminated");
shutdown_shards(shutdown, &shutdown_receiver);
exit_code
}
#[derive(Debug)]
enum ShardStatus {
Ready,
Failed((usize, String)),
}
#[allow(clippy::too_many_arguments)]
fn shard_manager(
model_name: String,
uds_path: String,
shard_directory: Option<String>,
rank: usize,
world_size: usize,
master_addr: String,
master_port: usize,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<Mutex<bool>>,
_shutdown_sender: mpsc::Sender<()>,
) {
// Get UDS path
let uds_string = format!("{}-{}", uds_path, rank);
let uds = Path::new(&uds_string);
// Clean previous runs
fs::remove_file(uds).unwrap_or_default();
// Process args
let mut shard_argv = vec![
"bloom-inference-server".to_string(),
"serve".to_string(),
model_name,
"--uds-path".to_string(),
uds_path,
];
if world_size > 1 {
shard_argv.push("--sharded".to_string());
}
if let Some(shard_directory) = shard_directory {
shard_argv.push("--shard-directory".to_string());
shard_argv.push(shard_directory);
}
// Start process
tracing::info!("Starting shard {}", rank);
let mut p = match Popen::create(
&shard_argv,
PopenConfig {
stdout: Redirection::Pipe,
stderr: Redirection::Pipe,
// Needed for the shutdown procedure
setpgid: true,
// NCCL env vars
env: Some(vec![
("RANK".parse().unwrap(), rank.to_string().parse().unwrap()),
(
"WORLD_SIZE".parse().unwrap(),
world_size.to_string().parse().unwrap(),
),
("MASTER_ADDR".parse().unwrap(), master_addr.parse().unwrap()),
(
"MASTER_PORT".parse().unwrap(),
master_port.to_string().parse().unwrap(),
),
]),
..Default::default()
},
) {
Ok(p) => p,
Err(err) => {
if let PopenError::IoError(ref err) = err {
if err.kind() == io::ErrorKind::NotFound {
tracing::error!("bloom-inference-server not found in PATH");
tracing::error!("Please install it with `make install-server`")
}
}
status_sender
.send(ShardStatus::Failed((rank, err.to_string())))
.unwrap();
return;
}
};
let mut ready = false;
let start_time = Instant::now();
loop {
// Process exited
if p.poll().is_some() {
let mut err = String::new();
p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
status_sender
.send(ShardStatus::Failed((rank, err)))
.unwrap();
return;
}
// We received a shutdown signal
if *shutdown.lock().unwrap() {
p.terminate().unwrap();
let _ = p.wait_timeout(Duration::from_secs(90));
tracing::info!("Shard {} terminated", rank);
return;
}
// Shard is ready
if uds.exists() && !ready {
tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed());
status_sender.send(ShardStatus::Ready).unwrap();
ready = true;
} else if !ready {
tracing::info!("Waiting for shard {} to be ready...", rank);
}
sleep(Duration::from_secs(5));
}
}
fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receiver<()>) {
tracing::info!("Shutting down shards");
// Update shutdown value to true
// This will be picked up by the shard manager
{
let mut shutdown = shutdown.lock().unwrap();
*shutdown = true;
}
// Wait for shards to shutdown
// This will block till all shutdown_sender are dropped
let _ = shutdown_receiver.recv();
}

View File

@ -11,10 +11,6 @@ service TextGenerationService {
rpc Generate (GenerateRequest) returns (GenerateResponse); rpc Generate (GenerateRequest) returns (GenerateResponse);
/// Generate tokens for a list of cached batches /// Generate tokens for a list of cached batches
rpc GenerateWithCache (GenerateWithCacheRequest) returns (GenerateWithCacheResponse); rpc GenerateWithCache (GenerateWithCacheRequest) returns (GenerateWithCacheResponse);
/// Generate tokens until the text of at least one request of the batch is generated
rpc GenerateUntilFinished (GenerateUntilFinishedRequest) returns (GenerateUntilFinishedResponse);
/// Generate tokens until the text of at least one request of the cached batches i finished
rpc GenerateUntilFinishedWithCache (GenerateUntilFinishedWithCacheRequest) returns (GenerateUntilFinishedWithCacheResponse);
} }
/// Empty request /// Empty request
@ -92,27 +88,3 @@ message GenerateWithCacheResponse {
/// Next batch (cached) /// Next batch (cached)
optional Batch batch = 2; optional Batch batch = 2;
} }
message GenerateUntilFinishedRequest {
/// Batch
Batch batch = 1;
}
message GenerateUntilFinishedResponse {
/// Finished requests
repeated GeneratedText generated_texts = 1;
/// Next batch (cached)
optional Batch batch = 2;
}
message GenerateUntilFinishedWithCacheRequest {
/// Cached batches
repeated Batch batches = 1;
}
message GenerateUntilFinishedWithCacheResponse {
/// Finished requests
repeated GeneratedText generated_texts = 1;
/// Next batch (cached)
optional Batch batch = 2;
}

1
router/.gitignore vendored
View File

@ -1 +0,0 @@
/target

View File

@ -22,16 +22,7 @@ serde = "1.0.145"
serde_json = "1.0.85" serde_json = "1.0.85"
thiserror = "1.0.37" thiserror = "1.0.37"
tokenizers = "0.13.0" tokenizers = "0.13.0"
tokio = { version = "1.21.1", features = ["rt-multi-thread", "parking_lot", "sync"] } tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tracing = "0.1.36" tracing = "0.1.36"
tracing-subscriber = "0.3.15" tracing-subscriber = "0.3.15"
[workspace]
members = [
"client",
]
[profile.release]
debug = 1
incremental = true
lto = "off"

View File

@ -5,8 +5,6 @@ edition = "2021"
[dependencies] [dependencies]
futures = "^0.3" futures = "^0.3"
#grpc-error-details = { path = "../../grpc-error-details" }
#grpc-metadata = { path = "../../grpc-metadata" }
prost = "^0.9" prost = "^0.9"
thiserror = "^1.0" thiserror = "^1.0"
tokio = { version = "^1.21", features = ["sync"] } tokio = { version = "^1.21", features = ["sync"] }

View File

@ -1,10 +1,11 @@
/// Single shard Client
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient; use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*; use crate::pb::generate::v1::*;
use crate::Result; use crate::Result;
use tonic::transport::{Channel, Uri}; use tonic::transport::{Channel, Uri};
use tracing::*; use tracing::*;
/// BLOOM Inference gRPC client /// Text Generation Inference gRPC client
#[derive(Clone)] #[derive(Clone)]
pub struct Client { pub struct Client {
stub: TextGenerationServiceClient<Channel>, stub: TextGenerationServiceClient<Channel>,
@ -34,6 +35,7 @@ impl Client {
}) })
} }
/// Returns a list of uris or unix sockets of all shards
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> { pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
let request = tonic::Request::new(ServiceDiscoveryRequest {}); let request = tonic::Request::new(ServiceDiscoveryRequest {});
@ -46,6 +48,7 @@ impl Client {
.into_inner() .into_inner()
.urls .urls
.into_iter() .into_iter()
// Remove unix socket prefix
.map(|url| match url.strip_prefix("unix://") { .map(|url| match url.strip_prefix("unix://") {
None => url, None => url,
Some(stripped_url) => stripped_url.to_string(), Some(stripped_url) => stripped_url.to_string(),
@ -54,6 +57,7 @@ impl Client {
Ok(urls) Ok(urls)
} }
/// Clear the past generations cache
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn clear_cache(&mut self) -> Result<()> { pub async fn clear_cache(&mut self) -> Result<()> {
let request = tonic::Request::new(ClearCacheRequest {}); let request = tonic::Request::new(ClearCacheRequest {});
@ -64,6 +68,10 @@ impl Client {
Ok(()) Ok(())
} }
/// Generate one token for each request in the given batch
///
/// Returns a list of generated texts of request that met their stopping criteria
/// and the next cached batch
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> { pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
let request = tonic::Request::new(GenerateRequest { batch: Some(batch) }); let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
@ -76,6 +84,10 @@ impl Client {
Ok((response.generated_texts, response.batch)) Ok((response.generated_texts, response.batch))
} }
/// Generate one token for each request in the given cached batch
///
/// Returns a list of generated texts of request that met their stopping criteria
/// and the next cached batch
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn generate_with_cache( pub async fn generate_with_cache(
&mut self, &mut self,
@ -90,34 +102,4 @@ impl Client {
.into_inner(); .into_inner();
Ok((response.generated_texts, response.batch)) Ok((response.generated_texts, response.batch))
} }
#[instrument(skip(self))]
pub async fn generate_until_finished(
&mut self,
batch: Batch,
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
let request = tonic::Request::new(GenerateUntilFinishedRequest { batch: Some(batch) });
let response = self
.stub
.generate_until_finished(request)
.instrument(info_span!("generate_until_finished"))
.await?
.into_inner();
Ok((response.generated_texts, response.batch))
}
#[instrument(skip(self))]
pub async fn generate_until_finished_with_cache(
&mut self,
batches: Vec<Batch>,
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
let request = tonic::Request::new(GenerateUntilFinishedWithCacheRequest { batches });
let response = self
.stub
.generate_until_finished_with_cache(request)
.instrument(info_span!("generate_until_finished_with_cache"))
.await?
.into_inner();
Ok((response.generated_texts, response.batch))
}
} }

View File

@ -1,6 +1,7 @@
//! BLOOM Inference gRPC client library //! Text Generation gRPC client library
mod client; mod client;
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb; mod pb;
mod sharded_client; mod sharded_client;
@ -8,7 +9,7 @@ pub use client::Client;
pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request}; pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request};
pub use sharded_client::ShardedClient; pub use sharded_client::ShardedClient;
use thiserror::Error; use thiserror::Error;
pub use tonic::transport; use tonic::transport;
use tonic::Status; use tonic::Status;
#[derive(Error, Debug, Clone)] #[derive(Error, Debug, Clone)]
@ -21,7 +22,7 @@ pub enum ClientError {
impl From<Status> for ClientError { impl From<Status> for ClientError {
fn from(err: Status) -> Self { fn from(err: Status) -> Self {
Self::Generation(err.to_string()) Self::Generation(err.message().to_string())
} }
} }

View File

@ -1,9 +1,11 @@
/// Multi shard Client
use crate::Result; use crate::Result;
use crate::{Batch, Client, GeneratedText}; use crate::{Batch, Client, GeneratedText};
use futures::future::join_all; use futures::future::join_all;
use tokio::sync::{broadcast, mpsc}; use tokio::sync::{broadcast, mpsc};
use tonic::transport::Uri; use tonic::transport::Uri;
/// List of all available commands that can be sent through the command channel
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
enum Command { enum Command {
Generate( Generate(
@ -14,36 +16,32 @@ enum Command {
Vec<Batch>, Vec<Batch>,
mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>, mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
), ),
GenerateUntilFinished(
Batch,
mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
),
GenerateUntilFinishedWithCache(
Vec<Batch>,
mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
),
ClearCache(mpsc::Sender<Result<()>>), ClearCache(mpsc::Sender<Result<()>>),
} }
/// Tokio task that handles the communication with a single shard
///
/// We subscribe on a broadcast channel to receive commands that will be sent by
/// the ShardedClient.
///
/// Each command is fan out to all shards.
///
/// The result of the command is sent back to the ShardedClient through a mpsc channel (multi
/// producer = the shards, single consumer = the ShardedClient).
async fn client_task(mut client: Client, mut request_subscriber: broadcast::Receiver<Command>) { async fn client_task(mut client: Client, mut request_subscriber: broadcast::Receiver<Command>) {
while let Ok(message) = request_subscriber.recv().await { while let Ok(message) = request_subscriber.recv().await {
match message { match message {
Command::Generate(batch, response_tx) => { Command::Generate(batch, response_tx) => {
let result = client.generate(batch).await; let result = client.generate(batch).await;
// We can unwrap_or(()) here because the only error that can happen is if the
// receiver is dropped, which means that the ShardedClient already received a
// response from another shard
response_tx.try_send(result).unwrap_or(()); response_tx.try_send(result).unwrap_or(());
} }
Command::GenerateWithCache(batches, response_tx) => { Command::GenerateWithCache(batches, response_tx) => {
let result = client.generate_with_cache(batches).await; let result = client.generate_with_cache(batches).await;
response_tx.try_send(result).unwrap_or(()); response_tx.try_send(result).unwrap_or(());
} }
Command::GenerateUntilFinished(batch, response_tx) => {
let result = client.generate_until_finished(batch).await;
response_tx.try_send(result).unwrap_or(());
}
Command::GenerateUntilFinishedWithCache(batches, response_tx) => {
let result = client.generate_until_finished_with_cache(batches).await;
response_tx.try_send(result).unwrap_or(());
}
Command::ClearCache(response_tx) => { Command::ClearCache(response_tx) => {
let result = client.clear_cache().await; let result = client.clear_cache().await;
response_tx.try_send(result).unwrap_or(()); response_tx.try_send(result).unwrap_or(());
@ -52,30 +50,42 @@ async fn client_task(mut client: Client, mut request_subscriber: broadcast::Rece
} }
} }
/// Text Generation Inference gRPC multi client
pub struct ShardedClient { pub struct ShardedClient {
_clients: Vec<Client>,
request_tx: broadcast::Sender<Command>, request_tx: broadcast::Sender<Command>,
} }
impl ShardedClient { impl ShardedClient {
fn new(mut clients: Vec<Client>) -> Self { fn new(clients: Vec<Client>) -> Self {
// The broadcast channel to communicate with the shards
// We use a capacity of one as the shards are not asynchronous and can only process one
// command at a time
let (request_tx, _) = broadcast::channel(1); let (request_tx, _) = broadcast::channel(1);
for client in clients.drain(..) { // Spawn client tasks
for client in clients.iter() {
let request_subscriber = request_tx.subscribe(); let request_subscriber = request_tx.subscribe();
tokio::spawn(client_task(client, request_subscriber)); tokio::spawn(client_task(client.clone(), request_subscriber));
} }
Self { request_tx } Self {
_clients: clients,
request_tx,
}
} }
/// Create a new ShardedClient from a master client. The master client will communicate with
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
async fn from_master_client(mut master_client: Client) -> Result<Self> { async fn from_master_client(mut master_client: Client) -> Result<Self> {
// Get all uris/unix sockets from the master client
let uris = master_client.service_discovery().await.unwrap(); let uris = master_client.service_discovery().await.unwrap();
let futures = uris.into_iter().map(|path| Client::connect_uds(path)); let futures = uris.into_iter().map(Client::connect_uds);
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect(); let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
Ok(Self::new(clients?)) Ok(Self::new(clients?))
} }
/// Returns a client connected to the given url /// Returns a client connected to the given uri
pub async fn connect(uri: Uri) -> Result<Self> { pub async fn connect(uri: Uri) -> Result<Self> {
let master_client = Client::connect(uri).await?; let master_client = Client::connect(uri).await?;
Self::from_master_client(master_client).await Self::from_master_client(master_client).await
@ -87,51 +97,43 @@ impl ShardedClient {
Self::from_master_client(master_client).await Self::from_master_client(master_client).await
} }
/// Generate one token for each request in the given batch
///
/// Returns a list of generated texts of request that met their stopping criteria
/// and the next cached batch
pub async fn generate(&self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> { pub async fn generate(&self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
// Create a channel to receive the response from the shards
// We will only ever receive one message on this channel
let (response_tx, mut response_rx) = mpsc::channel(1); let (response_tx, mut response_rx) = mpsc::channel(1);
self.request_tx self.request_tx
.send(Command::Generate(batch, response_tx)) .send(Command::Generate(batch, response_tx))
.unwrap(); .unwrap();
// As soon as we receive one response, we can return as all shards will return the same
response_rx.recv().await.unwrap() response_rx.recv().await.unwrap()
} }
/// Generate one token for each request in the given cached batch
///
/// Returns a list of generated texts of request that met their stopping criteria
/// and the next cached batch
pub async fn generate_with_cache( pub async fn generate_with_cache(
&self, &self,
batches: Vec<Batch>, batches: Vec<Batch>,
) -> Result<(Vec<GeneratedText>, Option<Batch>)> { ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
// Create a channel to receive the response from the shards
// We will only ever receive one message on this channel
let (response_tx, mut response_rx) = mpsc::channel(1); let (response_tx, mut response_rx) = mpsc::channel(1);
self.request_tx self.request_tx
.send(Command::GenerateWithCache(batches, response_tx)) .send(Command::GenerateWithCache(batches, response_tx))
.unwrap(); .unwrap();
// As soon as we receive one response, we can return as all shards will return the same
response_rx.recv().await.unwrap() response_rx.recv().await.unwrap()
} }
pub async fn generate_until_finished( /// Clear the past generations cache
&self,
batch: Batch,
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
let (response_tx, mut response_rx) = mpsc::channel(1);
self.request_tx
.send(Command::GenerateUntilFinished(batch, response_tx))
.unwrap();
response_rx.recv().await.unwrap()
}
pub async fn generate_until_finished_with_cache(
&self,
batches: Vec<Batch>,
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
let (response_tx, mut response_rx) = mpsc::channel(1);
self.request_tx
.send(Command::GenerateUntilFinishedWithCache(
batches,
response_tx,
))
.unwrap();
response_rx.recv().await.unwrap()
}
pub async fn clear_cache(&self) -> Result<()> { pub async fn clear_cache(&self) -> Result<()> {
// Create a channel to receive the response from the shards
// We will only ever receive one message on this channel
let (response_tx, mut response_rx) = mpsc::channel(1); let (response_tx, mut response_rx) = mpsc::channel(1);
self.request_tx self.request_tx
.send(Command::ClearCache(response_tx)) .send(Command::ClearCache(response_tx))

View File

@ -1,129 +1,158 @@
use crate::server::GenerateRequest; /// Batching and inference logic
use crate::GenerateRequest;
use crate::{Db, Entry}; use crate::{Db, Entry};
use axum::http::StatusCode; use axum::http::StatusCode;
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient}; use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
use std::future::Future; use std::future::Future;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use thiserror::Error; use thiserror::Error;
use tokio::sync::{oneshot, Notify}; use tokio::sync::{oneshot, Notify};
use tokio::time::Instant;
use tracing::instrument;
const MAX_LENGTH: usize = 128; /// Batcher
#[derive(Debug, Error)]
pub enum InferError {
#[error("Request failed during generation: {0}")]
GenerationError(String),
#[error("Model is overloaded")]
Overloaded,
}
impl From<InferError> for (StatusCode, String) {
fn from(err: InferError) -> Self {
match err {
InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
InferError::Overloaded => (StatusCode::TOO_MANY_REQUESTS, err.to_string()),
}
}
}
#[derive(Clone)] #[derive(Clone)]
pub struct Batcher { pub struct Batcher {
/// Request database
db: Db, db: Db,
/// Shared state
shared: Arc<Shared>, shared: Arc<Shared>,
} }
/// Batcher shared state
struct Shared { struct Shared {
/// Batching background Tokio task notifier
batching_task: Notify, batching_task: Notify,
} }
impl Batcher { impl Batcher {
pub(crate) fn new(client: ShardedClient, max_batch_size: usize) -> Self { pub(crate) fn new(
client: ShardedClient,
max_batch_size: usize,
max_waiting_time: Duration,
) -> Self {
// Batcher shared state
let db = Db::new(); let db = Db::new();
let shared = Arc::new(Shared { let shared = Arc::new(Shared {
batching_task: Notify::new(), batching_task: Notify::new(),
}); });
tokio::spawn(batching_task(max_batch_size, client, db.clone(), shared.clone())); // Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task(
max_batch_size,
max_waiting_time,
client,
db.clone(),
shared.clone(),
));
Self { db, shared } Self { db, shared }
} }
/// Add a new request to the database and return a future that will generate the text
pub(crate) async fn infer( pub(crate) async fn infer(
&self, &self,
input_length: usize, input_length: usize,
request: GenerateRequest, request: GenerateRequest,
) -> Result<String, InferError> { ) -> Result<String, InferError> {
if self.db.len() > MAX_LENGTH { // One shot channel to communicate with the background batching task
return Err(InferError::Overloaded); let (response_tx, response_rx) = oneshot::channel();
}
let (request_tx, request_rx) = oneshot::channel(); // Try to append the request to the database
self.db.append(Entry { self.db.append(Entry {
request, request,
response_tx: request_tx, response_tx,
input_length, input_length,
time: Instant::now(),
}); });
// Notify the background task that we have a new entry in the database that needs
// to be batched
self.shared.batching_task.notify_waiters(); self.shared.batching_task.notify_waiters();
match request_rx.await.unwrap() {
// Await on the response from the background task
// We can safely unwrap as the background task will never drop the sender
match response_rx.await.unwrap() {
Ok(output) => Ok(output), Ok(output) => Ok(output),
Err(err) => Err(InferError::GenerationError(err.to_string())), Err(err) => Err(InferError::GenerationError(err.to_string())),
} }
} }
} }
async fn batching_task(max_batch_size: usize, /// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[instrument(skip(client, db, shared))]
async fn batching_task(
max_batch_size: usize,
max_waiting_time: Duration,
client: ShardedClient, client: ShardedClient,
db: Db, db: Db,
shared: Arc<Shared>) { shared: Arc<Shared>,
) {
// Minimum batch size after which we try to add more requests
let limit_min_batch_size = (max_batch_size / 2) as u32; let limit_min_batch_size = (max_batch_size / 2) as u32;
// Infinite loop
loop { loop {
// Wait for a notification from the Batcher struct
shared.batching_task.notified().await; shared.batching_task.notified().await;
if let Some(batch) = db.next_batch(max_batch_size) { // Get the next batch from the DB
let request_ids = batch.requests.iter().map(|req| req.id).collect(); // This batch might be smaller than the maximum batch size if there are not enough requests
let mut cached_batch = match batch.size { // waiting in the DB
size if size > limit_min_batch_size => { if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size, None) {
wrap_future(client.generate_until_finished(batch), request_ids, &db).await let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
}
_ => wrap_future(client.generate(batch), request_ids, &db).await,
};
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch { while let Some(batch) = cached_batch {
let mut current_batch_size = batch.size; // Get current batch info
let batch_size = batch.size;
let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect(); let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
let mut batches = vec![batch]; let mut batches = vec![batch];
if current_batch_size <= limit_min_batch_size { // If the current batch is too small, we try to add more requests to it
if let Some(new_batch) = db.next_batch_minimum_size(limit_min_batch_size as usize, max_batch_size) { if batch_size <= limit_min_batch_size {
let new_batch_request_ids = // Get the next batch from the DB that meet our minimum size criteria
new_batch.requests.iter().map(|req| req.id).collect(); if let Some((new_request_ids, new_batch)) =
db.next_batch(Some(limit_min_batch_size as usize), max_batch_size, None)
{
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch = let new_cached_batch =
wrap_future(client.generate(new_batch), new_batch_request_ids, &db) wrap_future(client.generate(new_batch), new_request_ids, &db).await;
.await; // Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
batches.push(new_cached_batch);
}
}
// If we don't have enough requests to meet the minimum size criteria, we
// try to get the next batch from the DB that have been waiting over
// the max_waiting_time
else if let Some((new_request_ids, new_batch)) =
db.next_batch(None, max_batch_size, Some(max_waiting_time))
{
let new_cached_batch =
wrap_future(client.generate(new_batch), new_request_ids, &db).await;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch { if let Some(new_cached_batch) = new_cached_batch {
current_batch_size += new_cached_batch.size;
request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id)); request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
batches.push(new_cached_batch); batches.push(new_cached_batch);
} }
} }
} }
cached_batch = match current_batch_size { cached_batch =
size if size > limit_min_batch_size => { wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
wrap_future(
client.generate_until_finished_with_cache(batches),
request_ids,
&db,
)
.await
}
_ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await,
};
} }
} }
} }
} }
/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
async fn wrap_future( async fn wrap_future(
future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>, future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
request_ids: Vec<u64>, request_ids: Vec<u64>,
@ -134,6 +163,7 @@ async fn wrap_future(
send_generated(generated_texts, db); send_generated(generated_texts, db);
next_batch next_batch
} }
// If we have an error, we discard the whole batch
Err(err) => { Err(err) => {
send_error(err, request_ids, db); send_error(err, request_ids, db);
None None
@ -141,16 +171,20 @@ async fn wrap_future(
} }
} }
/// Send errors to the Batcher for all `request_ids`
fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) { fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
request_ids.into_iter().for_each(|id| { request_ids.into_iter().for_each(|id| {
// We can `expect` here as the request id should always be in the DB
let entry = db.remove(&id).expect("ID not found in db. This is a bug."); let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
// unwrap_or is valid here as we don't care if the receiver is gone. // unwrap_or is valid here as we don't care if the receiver is gone.
entry.response_tx.send(Err(error.clone())).unwrap_or(()); entry.response_tx.send(Err(error.clone())).unwrap_or(());
}); });
} }
/// Send `generated_text` to the Batcher for all `finished`
fn send_generated(finished: Vec<GeneratedText>, db: &Db) { fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
finished.into_iter().for_each(|output| { finished.into_iter().for_each(|output| {
// We can `expect` here as the request id should always be in the DB
let entry = db let entry = db
.remove(&output.request.unwrap().id) .remove(&output.request.unwrap().id)
.expect("ID not found in db. This is a bug."); .expect("ID not found in db. This is a bug.");
@ -158,3 +192,18 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
entry.response_tx.send(Ok(output.output)).unwrap_or(()); entry.response_tx.send(Ok(output.output)).unwrap_or(());
}); });
} }
#[derive(Debug, Error)]
pub enum InferError {
#[error("Request failed during generation: {0}")]
GenerationError(String),
}
/// Convert to Axum supported format
impl From<InferError> for (StatusCode, String) {
fn from(err: InferError) -> Self {
match err {
InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
}
}
}

View File

@ -1,16 +1,173 @@
/// This code is massively inspired by Tokio mini-redis /// This code is massively inspired by Tokio mini-redis
use crate::server::{GenerateParameters, GenerateRequest}; use crate::{GenerateParameters, GenerateRequest};
use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request}; use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
use parking_lot::RwLock; use parking_lot::Mutex;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use tokio::sync::oneshot::Sender; use tokio::sync::oneshot::Sender;
use tokio::time::Instant;
/// Database entry
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct Entry { pub(crate) struct Entry {
/// Request
pub request: GenerateRequest, pub request: GenerateRequest,
/// Response sender to communicate between the Batcher and the batching_task
pub response_tx: Sender<Result<String, ClientError>>, pub response_tx: Sender<Result<String, ClientError>>,
/// Number of tokens in the input
pub input_length: usize, pub input_length: usize,
/// Instant when this entry was created
pub time: Instant,
}
/// Request Database
#[derive(Debug, Clone)]
pub(crate) struct Db {
pub shared: Arc<Shared>,
}
/// Shared state
#[derive(Debug)]
pub struct Shared {
state: Mutex<State>,
}
/// Database State
#[derive(Debug)]
struct State {
/// Database entries organized in a BTreeMap to be able to iterate over them in order
entries: BTreeMap<u64, Entry>,
/// Id of the next entry
next_id: u64,
/// Id of the next batch
next_batch_id: u64,
/// Start ID of the next batch. Used to iterate inside the entries BTreeMap
next_batch_start_id: u64,
}
impl State {
/// Get the next requests
fn next_requests(
&self,
max_size: usize,
min_waiting_time: Option<Duration>,
) -> Option<(Vec<u64>, Vec<Request>)> {
// Iterates for max_size over the BTreemap starting from next_batch_start_id
let mut requests = Vec::new();
let mut ids = Vec::new();
for (id, entry) in self
.entries
// Start from next_batch_start_id
.range(self.next_batch_start_id..)
// Take max_size
.take(max_size)
{
if let Some(min_waiting_time) = min_waiting_time {
// Only take entries that waited for at least min_waiting_time
if entry.time.elapsed() < min_waiting_time {
// Since entries are ordered, we already know that all following entries won't
// satisfy the condition
break;
}
}
requests.push(Request {
id: *id,
inputs: entry.request.inputs.clone(),
input_length: entry.input_length as u32,
parameters: Some(LogitsWarperParameters::from(
entry.request.parameters.clone(),
)),
max_new_tokens: entry.request.parameters.max_new_tokens,
});
ids.push(*id);
}
if requests.is_empty() {
None
} else {
Some((ids, requests))
}
}
}
impl Db {
pub(crate) fn new() -> Self {
// Shared state
let shared = Arc::new(Shared {
state: Mutex::new(State {
entries: BTreeMap::new(),
next_id: 0,
next_batch_id: 0,
next_batch_start_id: 0,
}),
});
Self { shared }
}
/// Append an entry to the database
pub(crate) fn append(&self, entry: Entry) {
// Acquire lock
let mut state = self.shared.state.lock();
// Insert entry
let id = state.next_id;
state.next_id += 1;
state.entries.insert(id, entry);
}
/// Remove an entry from the database if it exists
pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
let mut state = self.shared.state.lock();
state.entries.remove(id)
}
// Get the next batch
pub(crate) fn next_batch(
&self,
min_size: Option<usize>,
max_size: usize,
min_waiting_time: Option<Duration>,
) -> Option<(Vec<u64>, Batch)> {
// Acquire lock
let mut state = self.shared.state.lock();
// Get requests from the database
if let Some((ids, requests)) = state.next_requests(max_size, min_waiting_time) {
if let Some(min_size) = min_size {
// If min_size is set, only return a batch if there are enough requests
if requests.len() < min_size {
return None;
}
}
// Batch size
let size = requests.len();
// Longest input length for all requests in batch size
// Used for padding inside the inference server
let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap();
let batch = Batch {
id: state.next_batch_id,
requests,
size: size as u32,
max_sequence_length,
};
// Update next_batch_start_id to the last id in the batch + 1
state.next_batch_start_id = ids.last().unwrap() + 1;
// Increment batch id
state.next_batch_id += 1;
return Some((ids, batch));
}
None
}
} }
impl From<GenerateParameters> for LogitsWarperParameters { impl From<GenerateParameters> for LogitsWarperParameters {
@ -23,129 +180,3 @@ impl From<GenerateParameters> for LogitsWarperParameters {
} }
} }
} }
#[derive(Debug, Clone)]
pub(crate) struct Db {
pub shared: Arc<Shared>,
}
#[derive(Debug)]
pub struct Shared {
state: RwLock<State>,
}
#[derive(Debug)]
struct State {
entries: BTreeMap<u64, Entry>,
/// Identifier to use for the next expiration. Each expiration is associated
/// with a unique identifier. See above for why.
next_id: u64,
next_batch_id: u64,
/// Current batch id
next_batch_start_id: u64,
}
impl Db {
pub(crate) fn new() -> Self {
let shared = Arc::new(Shared {
state: RwLock::new(State {
entries: BTreeMap::new(),
next_id: 0,
next_batch_id: 0,
next_batch_start_id: 0,
}),
});
Self { shared }
}
pub(crate) fn append(&self, entry: Entry) {
let mut state = self.shared.state.write();
let id = state.next_id;
state.next_id += 1;
state.entries.insert(id, entry);
}
pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
let mut state = self.shared.state.write();
state.entries.remove(id)
}
pub(crate) fn len(&self) -> usize {
let state = self.shared.state.read();
state.entries.len()
}
fn next_requests(&self, max_size: usize) -> Option<(u64, Vec<Request>)> {
let state = self.shared.state.read();
let requests: Vec<Request> = state
.entries
.range(state.next_batch_start_id..)
.take(max_size)
.map(|(id, entry)| Request {
id: *id,
inputs: entry.request.inputs.clone(),
input_length: entry.input_length as u32,
parameters: Some(LogitsWarperParameters::from(
entry.request.parameters.clone(),
)),
max_new_tokens: entry.request.parameters.max_new_tokens,
})
.collect();
if requests.is_empty() {
None
} else {
let last_id = requests.last().unwrap().id;
Some((last_id, requests))
}
}
pub(crate) fn next_batch(&self, max_size: usize) -> Option<Batch> {
if let Some((last_id, requests)) = self.next_requests(max_size) {
let mut state = self.shared.state.write();
let size = requests.len();
let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap();
let batch = Batch {
id: state.next_batch_id,
requests,
size: size as u32,
max_sequence_length,
};
state.next_batch_start_id = last_id + 1;
state.next_batch_id += 1;
return Some(batch);
}
None
}
pub(crate) fn next_batch_minimum_size(
&self,
min_size: usize,
max_size: usize,
) -> Option<Batch> {
if let Some((last_id, requests)) = self.next_requests(max_size) {
if requests.len() >= min_size {
let mut state = self.shared.state.write();
let size = requests.len();
let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap();
let batch = Batch {
id: state.next_batch_id,
requests,
size: size as u32,
max_sequence_length,
};
state.next_batch_start_id = last_id + 1;
state.next_batch_id += 1;
return Some(batch);
}
}
None
}
}

View File

@ -1,8 +1,68 @@
/// Text Generation Inference Webserver
mod batcher; mod batcher;
mod db; mod db;
mod validation;
pub mod server; pub mod server;
mod validation;
use db::{Db, Entry};
use batcher::Batcher; use batcher::Batcher;
use db::{Db, Entry};
use serde::{Deserialize, Serialize};
use validation::Validation; use validation::Validation;
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct GenerateParameters {
#[serde(default = "default_temperature")]
pub temperature: f32,
#[serde(default = "default_top_k")]
pub top_k: i32,
#[serde(default = "default_top_p")]
pub top_p: f32,
#[serde(default = "default_do_sample")]
pub do_sample: bool,
#[serde(default = "default_max_new_tokens")]
pub max_new_tokens: u32,
}
fn default_temperature() -> f32 {
1.0
}
fn default_top_k() -> i32 {
0
}
fn default_top_p() -> f32 {
1.0
}
fn default_do_sample() -> bool {
false
}
fn default_max_new_tokens() -> u32 {
20
}
fn default_parameters() -> GenerateParameters {
GenerateParameters {
temperature: default_temperature(),
top_k: default_top_k(),
top_p: default_top_p(),
do_sample: default_do_sample(),
max_new_tokens: default_max_new_tokens(),
}
}
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct GenerateRequest {
pub inputs: String,
#[serde(default = "default_parameters")]
pub parameters: GenerateParameters,
}
#[derive(Serialize)]
pub(crate) struct GeneratedText {
pub generated_text: String,
}
pub(crate) type GenerateResponse = Vec<GeneratedText>;

View File

@ -1,37 +1,61 @@
/// Text Generation Inference webserver entrypoint
use bloom_inference_client::ShardedClient; use bloom_inference_client::ShardedClient;
use clap::Parser;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::Duration;
use text_generation_router::server; use text_generation_router::server;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use clap::Parser;
/// App Configuration /// App Configuration
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)] #[clap(author, version, about, long_about = None)]
struct Args { struct Args {
#[clap(default_value = "32", long, short, env)] #[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "1000", long, env)]
max_input_length: usize,
#[clap(default_value = "32", long, env)]
max_batch_size: usize, max_batch_size: usize,
#[clap(default_value = "5", long, env)]
max_waiting_time: u64,
#[clap(default_value = "3000", long, short, env)] #[clap(default_value = "3000", long, short, env)]
port: u16, port: u16,
#[clap(default_value = "/tmp/bloom-inference-0", long, env)] #[clap(default_value = "/tmp/bloom-inference-0", long, env)]
shard_uds_path: String, master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)] #[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String, tokenizer_name: String,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
} }
fn main() -> Result<(), std::io::Error> { fn main() -> Result<(), std::io::Error> {
// Get args // Get args
let args = Args::parse(); let args = Args::parse();
// Pattern match configuration // Pattern match configuration
let Args { let Args {
max_concurrent_requests,
max_input_length,
max_batch_size, max_batch_size,
max_waiting_time,
port, port,
shard_uds_path, master_shard_uds_path,
tokenizer_name, tokenizer_name,
validation_workers,
} = args; } = args;
if validation_workers == 1 {
panic!("validation_workers must be > 0");
}
let max_waiting_time = Duration::from_secs(max_waiting_time);
// Download and instantiate tokenizer
// This will only be used to validate payloads
//
// We need to download it outside of the Tokio runtime
let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap(); let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();
// Launch Tokio runtime
tokio::runtime::Builder::new_multi_thread() tokio::runtime::Builder::new_multi_thread()
.enable_all() .enable_all()
.build() .build()
@ -39,18 +63,32 @@ fn main() -> Result<(), std::io::Error> {
.block_on(async { .block_on(async {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
let sharded_client = ShardedClient::connect_uds(shard_uds_path) // Instantiate sharded client from the master unix socket
let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await .await
.expect("Could not connect to server"); .expect("Could not connect to server");
// Clear the cache; useful if the webserver rebooted
sharded_client sharded_client
.clear_cache() .clear_cache()
.await .await
.expect("Unable to clear cache"); .expect("Unable to clear cache");
tracing::info!("Connected"); tracing::info!("Connected");
// Binds on localhost
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port); let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
server::run(max_batch_size, sharded_client, tokenizer, addr).await; // Run server
server::run(
max_concurrent_requests,
max_input_length,
max_batch_size,
max_waiting_time,
sharded_client,
tokenizer,
validation_workers,
addr,
)
.await;
Ok(()) Ok(())
}) })
} }

View File

@ -1,68 +1,44 @@
use crate::{Batcher, Validation}; use crate::{
Batcher, GenerateParameters, GenerateRequest, GenerateResponse, GeneratedText, Validation,
};
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{Json, Router}; use axum::{Json, Router};
use bloom_inference_client::ShardedClient; use bloom_inference_client::ShardedClient;
use serde::Deserialize;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tokio::signal;
use tokio::sync::Semaphore;
use tokio::time::Instant; use tokio::time::Instant;
use tracing::instrument; use tracing::instrument;
#[derive(Clone, Debug, Deserialize)] // Server shared state
pub(crate) struct GenerateParameters { #[derive(Clone)]
#[serde(default = "default_temperature")] struct ServerState {
pub temperature: f32, validation: Validation,
#[serde(default = "default_top_k")] batcher: Batcher,
pub top_k: i32, limit_concurrent_requests: Arc<Semaphore>,
#[serde(default = "default_top_p")]
pub top_p: f32,
#[serde(default = "default_do_sample")]
pub do_sample: bool,
#[serde(default = "default_max_new_tokens")]
pub max_new_tokens: u32,
}
fn default_temperature() -> f32 {
1.0
}
fn default_top_k() -> i32 {
0
}
fn default_top_p() -> f32 {
1.0
}
fn default_do_sample() -> bool {
false
}
fn default_max_new_tokens() -> u32 {
20
}
fn default_parameters() -> GenerateParameters {
GenerateParameters {
temperature: default_temperature(),
top_k: default_top_k(),
top_p: default_top_p(),
do_sample: default_do_sample(),
max_new_tokens: default_max_new_tokens(),
}
}
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct GenerateRequest {
pub inputs: String,
#[serde(default = "default_parameters")]
pub parameters: GenerateParameters,
} }
/// Health check method
#[instrument(skip(state), fields(time, time_per_token))] #[instrument(skip(state), fields(time, time_per_token))]
async fn liveness(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> { async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
// be a bit too slow for a health check.
// What we should do instead if check if the gRPC channels are still healthy.
// Limit concurrent requests by acquiring a permit from the semaphore
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
(
StatusCode::TOO_MANY_REQUESTS,
"Model is overloaded".to_string(),
)
})?;
// Send a small inference request
state state
.batcher .batcher
.infer( .infer(
@ -82,23 +58,35 @@ async fn liveness(state: Extension<ServerState>) -> Result<(), (StatusCode, Stri
Ok(()) Ok(())
} }
/// Generate method
#[instrument(skip(state), fields(time, time_per_token))] #[instrument(skip(state), fields(time, time_per_token))]
async fn generate( async fn generate(
state: Extension<ServerState>, state: Extension<ServerState>,
req: Json<GenerateRequest>, req: Json<GenerateRequest>,
) -> Result<Json<serde_json::Value>, (StatusCode, String)> { ) -> Result<Json<GenerateResponse>, (StatusCode, String)> {
let start = Instant::now(); let start = Instant::now();
// Limit concurrent requests by acquiring a permit from the semaphore
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
(
StatusCode::TOO_MANY_REQUESTS,
"Model is overloaded".to_string(),
)
})?;
// Validate request
let (input_length, validated_request) = state let (input_length, validated_request) = state
.validation .validation
// FIXME: can't we get rid of the cloning here??
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: req.inputs.clone(), inputs: req.inputs.clone(),
parameters: req.parameters.clone(), parameters: req.parameters.clone(),
}) })
.await?; .await?;
// Inference
let generated_text = state.batcher.infer(input_length, validated_request).await?; let generated_text = state.batcher.infer(input_length, validated_request).await?;
// Tracing metadata
tracing::Span::current().record("time", format!("{:?}", start.elapsed())); tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
tracing::Span::current().record( tracing::Span::current().record(
"time_per_token", "time_per_token",
@ -106,31 +94,71 @@ async fn generate(
); );
tracing::info!("response: {}", generated_text); tracing::info!("response: {}", generated_text);
Ok(Json(serde_json::json!({ // Send response
"generated_text": generated_text, let response = vec![GeneratedText { generated_text }];
}))) Ok(Json(response))
} }
#[derive(Clone)] /// Serving method
struct ServerState { #[allow(clippy::too_many_arguments)]
validation: Validation, pub async fn run(
batcher: Batcher, max_concurrent_requests: usize,
} max_input_length: usize,
max_batch_size: usize,
pub async fn run(max_batch_size: usize, client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) { max_waiting_time: Duration,
let batcher = Batcher::new(client, max_batch_size); client: ShardedClient,
let validation = Validation::new(tokenizer); tokenizer: Tokenizer,
validation_workers: usize,
let shared_state = ServerState { validation, batcher }; addr: SocketAddr,
) {
// Create state
let batcher = Batcher::new(client, max_batch_size, max_waiting_time);
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
let shared_state = ServerState {
validation,
batcher,
limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
};
// Create router
let app = Router::new() let app = Router::new()
.route("/generate", post(generate)) .route("/generate", post(generate))
.layer(Extension(shared_state.clone())) .layer(Extension(shared_state.clone()))
.route("/health", get(liveness)) .route("/health", get(health))
.layer(Extension(shared_state.clone())); .layer(Extension(shared_state.clone()));
// Run server
axum::Server::bind(&addr) axum::Server::bind(&addr)
.serve(app.into_make_service()) .serve(app.into_make_service())
// Wait until all requests are finished to shut down
.with_graceful_shutdown(shutdown_signal())
.await .await
.unwrap(); .unwrap();
} }
/// Shutdown signal handler
async fn shutdown_signal() {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}
tracing::info!("signal received, starting graceful shutdown");
}

View File

@ -1,62 +1,105 @@
use crate::server::GenerateRequest; /// Payload validation logic
use crate::GenerateRequest;
use axum::http::StatusCode; use axum::http::StatusCode;
use thiserror::Error; use thiserror::Error;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokenizers::{
DecoderWrapper, ModelWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper,
TokenizerImpl,
};
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
#[derive(Error, Debug)] /// Validation
pub enum ValidationError {
#[error("Temperature must be strictly positive")]
Temperature,
#[error("Top p must be <= 0.0 or > 1.0")]
TopP,
#[error("Top k must be strictly positive")]
TopK,
#[error("Max New Tokens must be < 512")]
MaxNewTokens,
#[error("Inputs must have less than 1000 tokens. Given: {0}")]
InputLength(usize),
}
impl From<ValidationError> for (StatusCode, String) {
fn from(err: ValidationError) -> Self {
(StatusCode::BAD_REQUEST, err.to_string())
}
}
type ValidationRequest = (
GenerateRequest,
oneshot::Sender<Result<(usize, GenerateRequest), ValidationError>>,
);
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Validation { pub struct Validation {
/// Channel to communicate with the background validation task
sender: mpsc::Sender<ValidationRequest>, sender: mpsc::Sender<ValidationRequest>,
} }
impl Validation { impl Validation {
pub(crate) fn new(tokenizer: Tokenizer) -> Self { pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self {
// Crate channel
let (validation_sender, validation_receiver) = mpsc::channel(128); let (validation_sender, validation_receiver) = mpsc::channel(128);
tokio::spawn(validation_task(tokenizer, validation_receiver)); // Launch background validation task
tokio::spawn(validation_task(
workers,
tokenizer,
max_input_length,
validation_receiver,
));
Self { Self {
sender: validation_sender, sender: validation_sender,
} }
} }
/// Validate a payload and get the number of tokens in the input
pub(crate) async fn validate( pub(crate) async fn validate(
&self, &self,
request: GenerateRequest, request: GenerateRequest,
) -> Result<(usize, GenerateRequest), ValidationError> { ) -> Result<(usize, GenerateRequest), ValidationError> {
// Create response channel
let (sender, receiver) = oneshot::channel(); let (sender, receiver) = oneshot::channel();
// Send request to the background validation task
// Unwrap is safe here
self.sender.send((request, sender)).await.unwrap(); self.sender.send((request, sender)).await.unwrap();
// Await on response channel
// Unwrap is safe here
receiver.await.unwrap() receiver.await.unwrap()
} }
} }
async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<ValidationRequest>) { /// Validation task
while let Some((request, response_tx)) = receiver.recv().await { /// Load balance the validation requests between multiple validation workers
async fn validation_task(
workers: usize,
tokenizer: Tokenizer,
max_input_length: usize,
mut receiver: mpsc::Receiver<ValidationRequest>,
) {
let mut workers_senders = Vec::with_capacity(workers);
// Create workers
for _ in 0..workers {
let tokenizer_clone = tokenizer.clone();
// Create channel to communicate with worker
let (worker_sender, worker_receiver) = mpsc::channel(workers);
workers_senders.push(worker_sender);
// Spawn worker
tokio::task::spawn_blocking(move || {
validation_worker(tokenizer_clone, max_input_length, worker_receiver)
});
}
loop {
// Load balance requests between workers
for sender in workers_senders.iter() {
if let Some(validation_request) = receiver.recv().await {
sender.send(validation_request).await.unwrap();
} else {
return;
}
}
}
}
/// Check the parameters inside the payload and get the number of tokens inside the input using
/// the tokenizer
fn validation_worker(
tokenizer: TokenizerImpl<
ModelWrapper,
NormalizerWrapper,
PreTokenizerWrapper,
PostProcessorWrapper,
DecoderWrapper,
>,
max_input_length: usize,
mut receiver: mpsc::Receiver<ValidationRequest>,
) {
// Loop over requests
while let Some((request, response_tx)) = receiver.blocking_recv() {
if request.parameters.temperature < 0.0 { if request.parameters.temperature < 0.0 {
response_tx response_tx
.send(Err(ValidationError::Temperature)) .send(Err(ValidationError::Temperature))
@ -78,10 +121,11 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<Vali
continue; continue;
} }
// Get the number of tokens in the input
let inputs = tokenizer.encode(request.inputs.clone(), false).unwrap(); let inputs = tokenizer.encode(request.inputs.clone(), false).unwrap();
let input_length = inputs.len(); let input_length = inputs.len();
if input_length > 1000 { if input_length > max_input_length {
response_tx response_tx
.send(Err(ValidationError::InputLength(input_length))) .send(Err(ValidationError::InputLength(input_length)))
.unwrap_or(()); .unwrap_or(());
@ -91,3 +135,28 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<Vali
response_tx.send(Ok((input_length, request))).unwrap_or(()); response_tx.send(Ok((input_length, request))).unwrap_or(());
} }
} }
type ValidationRequest = (
GenerateRequest,
oneshot::Sender<Result<(usize, GenerateRequest), ValidationError>>,
);
#[derive(Error, Debug)]
pub enum ValidationError {
#[error("Temperature must be strictly positive")]
Temperature,
#[error("Top p must be <= 0.0 or > 1.0")]
TopP,
#[error("Top k must be strictly positive")]
TopK,
#[error("Max New Tokens must be < 512")]
MaxNewTokens,
#[error("Inputs must have less than 1000 tokens. Given: {0}")]
InputLength(usize),
}
impl From<ValidationError> for (StatusCode, String) {
fn from(err: ValidationError) -> Self {
(StatusCode::BAD_REQUEST, err.to_string())
}
}

30
run.sh
View File

@ -1,30 +0,0 @@
#!/usr/bin/env bash
server_cmd="bloom-inference-server launcher $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"
# Run in background
$server_cmd 2>&1 > /dev/null &
# Check if server is running by checking if the unix socket is created
FILE=/tmp/bloom-inference-0
while :
do
if test -S "$FILE"; then
echo "Text Generation Python gRPC server started"
break
else
echo "Waiting for Text Generation Python gRPC server to start"
sleep 5
fi
done
sleep 1
# Run in background
text-generation-router &
# Wait for any process to exit
wait -n
# Exit with status of process that exited first
exit $?

155
server/.gitignore vendored Normal file
View File

@ -0,0 +1,155 @@
# Byte-compiled / optimized / DLL files
__pycache__/
bloom_inference/__pycache__/
bloom_inference/pb/__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/

View File

@ -4,17 +4,28 @@ gen-server:
find bloom_inference/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; find bloom_inference/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch bloom_inference/pb/__init__.py touch bloom_inference/pb/__init__.py
unit-tests: install-transformers:
python -m pytest --cov=bloom_inference tests # Install specific version of transformers
rm transformers || true
wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
mv transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 transformers
cd transformers && python setup.py install
unit-tests-reporting: install-torch:
python -m pytest --junitxml=report.xml --cov=bloom_inference tests # Install specific version of torch
pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir
pip-install: pip-install:
pip install grpcio-tools pip install grpcio-tools
make gen-server make gen-server
make install-torch
make install-transformers
pip install . pip install .
install: install:
poetry install poetry install
make gen-server make gen-server
make install-torch
make install-transformers

View File

@ -1,41 +1,51 @@
import os
import typer import typer
from pathlib import Path from pathlib import Path
from torch.distributed.launcher import launch_agent, LaunchConfig
from typing import Optional from typing import Optional
from bloom_inference import server from bloom_inference import prepare_weights, server
app = typer.Typer() app = typer.Typer()
@app.command()
def launcher(
model_name: str,
num_gpus: int = 1,
shard_directory: Optional[Path] = None,
):
if num_gpus == 1:
serve(model_name, False, shard_directory)
else:
config = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=num_gpus,
rdzv_backend="c10d",
max_restarts=0,
)
launch_agent(config, server.serve, [model_name, True, shard_directory])
@app.command() @app.command()
def serve( def serve(
model_name: str, model_name: str,
sharded: bool = False, sharded: bool = False,
shard_directory: Optional[Path] = None, shard_directory: Optional[Path] = None,
uds_path: Path = "/tmp/bloom-inference",
): ):
server.serve(model_name, sharded, shard_directory) if sharded:
assert (
shard_directory is not None
), "shard_directory must be set when sharded is True"
assert (
os.getenv("RANK", None) is not None
), "RANK must be set when sharded is True"
assert (
os.getenv("WORLD_SIZE", None) is not None
), "WORLD_SIZE must be set when sharded is True"
assert (
os.getenv("MASTER_ADDR", None) is not None
), "MASTER_ADDR must be set when sharded is True"
assert (
os.getenv("MASTER_PORT", None) is not None
), "MASTER_PORT must be set when sharded is True"
server.serve(model_name, sharded, uds_path, shard_directory)
@app.command()
def prepare_weights(
model_name: str,
shard_directory: Path,
cache_directory: Path,
num_shard: int = 1,
):
prepare_weights.prepare_weights(
model_name, cache_directory, shard_directory, num_shard
)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -24,6 +24,7 @@ torch.manual_seed(0)
class Batch: class Batch:
batch_id: int batch_id: int
requests: List[generate_pb2.Request] requests: List[generate_pb2.Request]
all_input_lengths: List[int]
input_ids: Dict[str, torch.Tensor] input_ids: Dict[str, torch.Tensor]
all_input_ids: List[torch.Tensor] all_input_ids: List[torch.Tensor]
next_token_choosers: List[NextTokenChooser] next_token_choosers: List[NextTokenChooser]
@ -46,12 +47,12 @@ class Batch:
inputs = [] inputs = []
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
input_lengths = [] all_input_lengths = []
# Parse batch # Parse batch
for r in pb.requests: for r in pb.requests:
inputs.append(r.inputs) inputs.append(r.inputs)
input_lengths.append(r.input_length) all_input_lengths.append(r.input_length)
next_token_choosers.append( next_token_choosers.append(
NextTokenChooser( NextTokenChooser(
temperature=r.parameters.temperature, temperature=r.parameters.temperature,
@ -63,17 +64,12 @@ class Batch:
stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens)) stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))
input_ids = tokenizer(inputs, return_tensors="pt", padding=True).to(device) input_ids = tokenizer(inputs, return_tensors="pt", padding=True).to(device)
# Remove padding from all_input_ids all_input_ids = input_ids["input_ids"].unsqueeze(-1)
all_input_ids = [
input_ids.squeeze(0)[-length:].unsqueeze(-1)
for length, input_ids in zip(
input_lengths, input_ids["input_ids"].split(1, dim=0)
)
]
return cls( return cls(
batch_id=pb.id, batch_id=pb.id,
requests=pb.requests, requests=pb.requests,
all_input_lengths=all_input_lengths,
input_ids=input_ids, input_ids=input_ids,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
@ -91,6 +87,7 @@ class Batch:
# Batch attributes # Batch attributes
input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []} input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
requests = [] requests = []
all_input_lengths = []
all_input_ids = [] all_input_ids = []
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
@ -100,6 +97,7 @@ class Batch:
start_index = 0 start_index = 0
for i, batch in enumerate(batches): for i, batch in enumerate(batches):
requests.extend(batch.requests) requests.extend(batch.requests)
all_input_lengths.extend(batch.all_input_lengths)
all_input_ids.extend(batch.all_input_ids) all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers) next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias) stopping_criterias.extend(batch.stopping_criterias)
@ -198,6 +196,7 @@ class Batch:
return cls( return cls(
batch_id=batches[0].batch_id, batch_id=batches[0].batch_id,
requests=requests, requests=requests,
all_input_lengths=all_input_lengths,
input_ids=input_ids, input_ids=input_ids,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
@ -227,7 +226,10 @@ class BLOOM:
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
self.model = ( self.model = (
AutoModelForCausalLM.from_pretrained(model_name).eval().to(self.device).to(dtype) AutoModelForCausalLM.from_pretrained(model_name)
.eval()
.to(self.device)
.to(dtype)
) )
self.num_heads = self.model.base_model.num_heads self.num_heads = self.model.base_model.num_heads
@ -253,6 +255,7 @@ class BLOOM:
# New input_ids for next forward # New input_ids for next forward
next_batch_input_ids = [] next_batch_input_ids = []
next_batch_all_input_ids = [] next_batch_all_input_ids = []
next_all_input_lengths = []
next_batch_size = 0 next_batch_size = 0
next_batch_max_sequence_length = 0 next_batch_max_sequence_length = 0
@ -263,6 +266,7 @@ class BLOOM:
# Zipped iterator # Zipped iterator
iterator = zip( iterator = zip(
batch.requests, batch.requests,
batch.all_input_lengths,
outputs.logits, outputs.logits,
batch.next_token_choosers, batch.next_token_choosers,
batch.stopping_criterias, batch.stopping_criterias,
@ -272,6 +276,7 @@ class BLOOM:
# For each member of the batch # For each member of the batch
for i, ( for i, (
request, request,
input_length,
logits, logits,
next_token_chooser, next_token_chooser,
stopping_criteria, stopping_criteria,
@ -302,8 +307,10 @@ class BLOOM:
next_batch_input_ids.append(next_token) next_batch_input_ids.append(next_token)
next_batch_all_input_ids.append(all_tokens) next_batch_all_input_ids.append(all_tokens)
next_batch_size += 1 next_batch_size += 1
new_input_length = input_length + 1
next_all_input_lengths.append(new_input_length)
next_batch_max_sequence_length = max( next_batch_max_sequence_length = max(
next_batch_max_sequence_length, len(all_tokens) next_batch_max_sequence_length, new_input_length
) )
# We finished all generations in the batch; there is no next batch # We finished all generations in the batch; there is no next batch
@ -350,6 +357,7 @@ class BLOOM:
next_batch = Batch( next_batch = Batch(
batch_id=batch.batch_id, batch_id=batch.batch_id,
requests=next_batch_requests, requests=next_batch_requests,
all_input_lengths=next_all_input_lengths,
input_ids=next_batch_input_ids, input_ids=next_batch_input_ids,
all_input_ids=next_batch_all_input_ids, all_input_ids=next_batch_all_input_ids,
next_token_choosers=next_batch_next_token_choosers, next_token_choosers=next_batch_next_token_choosers,
@ -378,7 +386,10 @@ class BLOOMSharded(BLOOM):
if self.master: if self.master:
# TODO @thomasw21 do some caching # TODO @thomasw21 do some caching
shard_state_dict_paths = prepare_weights( shard_state_dict_paths = prepare_weights(
model_name, shard_directory / "cache", shard_directory, tp_world_size=self.world_size model_name,
shard_directory / "cache",
shard_directory,
tp_world_size=self.world_size,
) )
shard_state_dict_paths = [ shard_state_dict_paths = [
str(path.absolute()) for path in shard_state_dict_paths str(path.absolute()) for path in shard_state_dict_paths
@ -443,6 +454,7 @@ class BLOOMSharded(BLOOM):
use_cache=True, use_cache=True,
) )
# Logits are sharded, so we need to gather them
logits_shard = outputs.logits[:, -1, :].contiguous() logits_shard = outputs.logits[:, -1, :].contiguous()
batch_size, vocab_shard_size = logits_shard.shape batch_size, vocab_shard_size = logits_shard.shape

View File

@ -14,7 +14,7 @@ from huggingface_hub.file_download import _request_wrapper, hf_raise_for_status
def match_suffix(text, suffix): def match_suffix(text, suffix):
return text[-len(suffix):] == suffix return text[-len(suffix) :] == suffix
def http_get( def http_get(
@ -54,7 +54,9 @@ def cache_download_url(url: str, root_dir: Path):
return filename return filename
def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world_size: int): def prepare_weights(
model_name: str, cache_path: Path, save_path: Path, tp_world_size: int
):
save_paths = [ save_paths = [
save_path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty" save_path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
for tp_rank in range(tp_world_size) for tp_rank in range(tp_world_size)
@ -68,6 +70,7 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
if model_name == "bigscience/bloom-560m": if model_name == "bigscience/bloom-560m":
url = hf_hub_url(model_name, filename="pytorch_model.bin") url = hf_hub_url(model_name, filename="pytorch_model.bin")
cache_download_url(url, cache_path) cache_download_url(url, cache_path)
elif model_name == "bigscience/bloom": elif model_name == "bigscience/bloom":
url = hf_hub_url(model_name, filename="pytorch_model.bin.index.json") url = hf_hub_url(model_name, filename="pytorch_model.bin.index.json")
index_path = cache_download_url(url, cache_path) index_path = cache_download_url(url, cache_path)
@ -75,10 +78,14 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
index = json.load(f) index = json.load(f)
# Get unique file names # Get unique file names
weight_files = list(set([filename for filename in index["weight_map"].values()])) weight_files = list(
set([filename for filename in index["weight_map"].values()])
)
urls = [hf_hub_url(model_name, filename=filename) for filename in weight_files] urls = [hf_hub_url(model_name, filename=filename) for filename in weight_files]
Parallel(n_jobs=5)(delayed(cache_download_url)(url, cache_path) for url in tqdm(urls)) Parallel(n_jobs=5)(
delayed(cache_download_url)(url, cache_path) for url in tqdm(urls)
)
else: else:
raise ValueError(f"Unknown model name: {model_name}") raise ValueError(f"Unknown model name: {model_name}")
@ -107,7 +114,9 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
assert len(sharded_weights) == tp_world_size assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights): for tp_rank, shard in enumerate(sharded_weights):
shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone() shards_state_dicts[tp_rank][
"transformer." + state_name
] = shard.detach().clone()
elif match_suffix(state_name, "lm_head.weight"): elif match_suffix(state_name, "lm_head.weight"):
output_size = state.shape[0] output_size = state.shape[0]
@ -132,7 +141,9 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
sharded_weights = torch.split(state, block_size, dim=1) sharded_weights = torch.split(state, block_size, dim=1)
assert len(sharded_weights) == tp_world_size assert len(sharded_weights) == tp_world_size
for tp_rank, shard in enumerate(sharded_weights): for tp_rank, shard in enumerate(sharded_weights):
shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone() shards_state_dicts[tp_rank][
"transformer." + state_name
] = shard.detach().clone()
elif any( elif any(
match_suffix(state_name, candidate) match_suffix(state_name, candidate)
@ -141,14 +152,20 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
"mlp.dense_4h_to_h.bias", "mlp.dense_4h_to_h.bias",
] ]
): ):
shards_state_dicts[0]["transformer." + state_name] = state.detach().clone() shards_state_dicts[0][
"transformer." + state_name
] = state.detach().clone()
for tp_rank in range(1, tp_world_size): for tp_rank in range(1, tp_world_size):
shards_state_dicts[tp_rank]["transformer." + state_name] = torch.zeros_like(state) shards_state_dicts[tp_rank][
"transformer." + state_name
] = torch.zeros_like(state)
else: else:
# We duplicate parameters across tp ranks # We duplicate parameters across tp ranks
for tp_rank in range(tp_world_size): for tp_rank in range(tp_world_size):
shards_state_dicts[tp_rank]["transformer." + state_name] = state.detach().clone() shards_state_dicts[tp_rank][
"transformer." + state_name
] = state.detach().clone()
del state_dict[state_name] # delete key from state_dict del state_dict[state_name] # delete key from state_dict
del state # delete tensor del state # delete tensor
@ -166,17 +183,3 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
torch.save(shard_state_dict, save_path) torch.save(shard_state_dict, save_path)
return save_paths return save_paths
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--model-name", required=True, type=str)
parser.add_argument("--cache-path", required=True, type=str)
parser.add_argument("--save-path", required=True, type=str)
parser.add_argument("--world-size", required=True, type=int)
args = parser.parse_args()
prepare_weights(args.model_name, Path(args.cache_path), Path(args.save_path), args.world_size)

View File

@ -64,70 +64,31 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
batch=next_batch.to_pb() if next_batch else None, batch=next_batch.to_pb() if next_batch else None,
) )
async def GenerateUntilFinished(self, request, context):
batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device)
generated_texts = [] def serve(
while not generated_texts: model_name: str,
generated_texts, next_batch = self.model.generate_token(batch) sharded: bool,
batch = next_batch uds_path: Path,
self.cache.set(next_batch) shard_directory: Optional[Path] = None,
):
return generate_pb2.GenerateUntilFinishedResponse(
generated_texts=[
generated_text.to_pb() for generated_text in generated_texts
],
batch=next_batch.to_pb() if next_batch else None,
)
async def GenerateUntilFinishedWithCache(self, request, context):
if len(request.batches) == 0:
raise ValueError("Must provide at least one batch")
batches = []
for batch_pb in request.batches:
batch = self.cache.pop(batch_pb.id)
if batch is None:
raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
batches.append(batch)
if len(batches) > 1:
batch = Batch.concatenate(batches)
else:
batch = batches[0]
generated_texts = []
while not generated_texts:
generated_texts, next_batch = self.model.generate_token(batch)
batch = next_batch
self.cache.set(next_batch)
return generate_pb2.GenerateUntilFinishedWithCacheResponse(
generated_texts=[
generated_text.to_pb() for generated_text in generated_texts
],
batch=next_batch.to_pb() if next_batch else None,
)
def serve(model_name, sharded, shard_directory):
async def serve_inner( async def serve_inner(
model_name: str, model_name: str,
sharded: bool = False, sharded: bool = False,
shard_directory: Optional[Path] = None, shard_directory: Optional[Path] = None,
): ):
unix_socket_template = "unix:///tmp/bloom-inference-{}" unix_socket_template = "unix://{}-{}"
if sharded: if sharded:
if shard_directory is None: if shard_directory is None:
raise ValueError("shard_directory must be set when sharded is True") raise ValueError("shard_directory must be set when sharded is True")
model = BLOOMSharded(model_name, shard_directory) model = BLOOMSharded(model_name, shard_directory)
server_urls = [ server_urls = [
unix_socket_template.format(rank) for rank in range(model.world_size) unix_socket_template.format(uds_path, rank)
for rank in range(model.world_size)
] ]
local_url = unix_socket_template.format(model.rank) local_url = server_urls[model.rank]
else: else:
model = BLOOM(model_name) model = BLOOM(model_name)
local_url = unix_socket_template.format(0) local_url = unix_socket_template.format(uds_path, 0)
server_urls = [local_url] server_urls = [local_url]
server = aio.server() server = aio.server()
@ -142,6 +103,10 @@ def serve(model_name, sharded, shard_directory):
server.add_insecure_port(local_url) server.add_insecure_port(local_url)
await server.start() await server.start()
print("Server started at {}".format(local_url)) print("Server started at {}".format(local_url))
try:
await server.wait_for_termination() await server.wait_for_termination()
except KeyboardInterrupt:
print("Signal received. Shutting down")
await server.stop(0)
asyncio.run(serve_inner(model_name, sharded, shard_directory)) asyncio.run(serve_inner(model_name, sharded, shard_directory))

View File

@ -82,7 +82,6 @@ def initialize_torch_distributed():
world_size=world_size, world_size=world_size,
rank=rank, rank=rank,
timeout=timedelta(seconds=60), timeout=timedelta(seconds=60),
init_method="tcp://localhost:6000",
) )
return torch.distributed.distributed_c10d._get_default_group(), rank, world_size return torch.distributed.distributed_c10d._get_default_group(), rank, world_size

2
server/poetry.lock generated
View File

@ -205,7 +205,7 @@ python-versions = ">=3.7"
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.9" python-versions = "^3.9"
content-hash = "f3dc5b2420183f2e7e9257e372489409d7bd26d1dcc535fc2558ebca50c988c2" content-hash = "a4eef5f52e8d046aa883082c865b0865047f611a3240b18250487d4b6e831496"
[metadata.files] [metadata.files]
accelerate = [ accelerate = [

View File

@ -11,7 +11,6 @@ bloom-inference-server = 'bloom_inference.cli:app'
python = "^3.9" python = "^3.9"
protobuf = "^4.21.7" protobuf = "^4.21.7"
grpcio = "^1.49.1" grpcio = "^1.49.1"
torch = "^1.12.1"
typer = "^0.6.1" typer = "^0.6.1"
grpcio-reflection = "^1.49.1" grpcio-reflection = "^1.49.1"
accelerate = "^0.12.0" accelerate = "^0.12.0"