v0.1.0
This commit is contained in:
parent
92c1ecd008
commit
f16f2f5ae1
|
@ -1,2 +1,2 @@
|
|||
aml
|
||||
router/target
|
||||
target
|
|
@ -1 +1,2 @@
|
|||
.idea
|
||||
target
|
|
@ -55,9 +55,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.57"
|
||||
version = "0.1.58"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "76464446b8bc32758d7e88ee1a804d9914cd9b1cb264c029899680b0be29826f"
|
||||
checksum = "1e805d94e6b5001b651426cf4cd446b1ab5f319d27bab5c644f61de0a804360c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
@ -166,9 +166,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "3.11.0"
|
||||
version = "3.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c1ad822118d20d2c234f427000d5acc36eabe1e29a348c89b63dd60b13f28e5d"
|
||||
checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba"
|
||||
|
||||
[[package]]
|
||||
name = "byteorder"
|
||||
|
@ -255,9 +255,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.0.15"
|
||||
version = "4.0.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6bf8832993da70a4c6d13c581f4463c2bdda27b9bf1c5498dc4365543abe6d6f"
|
||||
checksum = "06badb543e734a2d6568e19a40af66ed5364360b9226184926f89d229b4b4267"
|
||||
dependencies = [
|
||||
"atty",
|
||||
"bitflags",
|
||||
|
@ -391,6 +391,16 @@ dependencies = [
|
|||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ctrlc"
|
||||
version = "3.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d91974fbbe88ec1df0c24a4f00f99583667a7e2e6272b2b92d294d81e462173"
|
||||
dependencies = [
|
||||
"nix",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.10.2"
|
||||
|
@ -529,7 +539,7 @@ dependencies = [
|
|||
"cfg-if",
|
||||
"libc",
|
||||
"redox_syscall",
|
||||
"windows-sys",
|
||||
"windows-sys 0.36.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -936,9 +946,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.3"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c8af84674fe1f223a982c933a0ee1086ac4d4052aa0fb8060c12c6ad838e754"
|
||||
checksum = "4217ad341ebadf8d8e724e264f13e593e0648f5b3e94b3896a5df283be015ecc"
|
||||
|
||||
[[package]]
|
||||
name = "js-sys"
|
||||
|
@ -957,9 +967,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
|||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.134"
|
||||
version = "0.2.135"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "329c933548736bc49fd575ee68c89e8be4d260064184389a5b77517cddd99ffb"
|
||||
checksum = "68783febc7782c6c5cb401fbda4de5a9898be1762314da0bb2c10ced61f18b0c"
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
|
@ -1047,7 +1057,7 @@ dependencies = [
|
|||
"libc",
|
||||
"log",
|
||||
"wasi 0.11.0+wasi-snapshot-preview1",
|
||||
"windows-sys",
|
||||
"windows-sys 0.36.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1074,6 +1084,18 @@ dependencies = [
|
|||
"tempfile",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nix"
|
||||
version = "0.25.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e322c04a9e3440c327fca7b6c8a63e6890a32fa2ad689db972425f07e0d22abb"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"bitflags",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nom"
|
||||
version = "7.1.1"
|
||||
|
@ -1084,6 +1106,16 @@ dependencies = [
|
|||
"minimal-lexical",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nu-ansi-term"
|
||||
version = "0.46.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
|
||||
dependencies = [
|
||||
"overload",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_cpus"
|
||||
version = "1.13.1"
|
||||
|
@ -1185,6 +1217,12 @@ version = "6.3.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff"
|
||||
|
||||
[[package]]
|
||||
name = "overload"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.12.1"
|
||||
|
@ -1197,15 +1235,15 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "parking_lot_core"
|
||||
version = "0.9.3"
|
||||
version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09a279cbf25cb0757810394fbc1e359949b59e348145c643a939a525692e6929"
|
||||
checksum = "4dc9e0dc2adc1c69d09143aff38d3d30c5c3f0df0dad82e6d25547af174ebec0"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"redox_syscall",
|
||||
"smallvec",
|
||||
"windows-sys",
|
||||
"windows-sys 0.42.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1300,9 +1338,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.46"
|
||||
version = "1.0.47"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94e2ef8dbfc347b10c094890f778ee2e36ca9bb4262e86dc99cd217e35f3470b"
|
||||
checksum = "5ea3d908b0e36316caf9e9e2c4625cdde190a7e6f440d794667ed17a1855e725"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
@ -1530,7 +1568,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "88d6731146462ea25d9244b2ed5fd1d716d25c52e4d54aa4fb0f3c4e9854dbe2"
|
||||
dependencies = [
|
||||
"lazy_static",
|
||||
"windows-sys",
|
||||
"windows-sys 0.36.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1584,9 +1622,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.85"
|
||||
version = "1.0.86"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e55a28e3aaef9d5ce0506d0a14dbba8054ddc7e499ef522dd8b26859ec9d4a44"
|
||||
checksum = "41feea4228a6f1cd09ec7a3593a682276702cd67b5273544757dae23c096f074"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
|
@ -1625,6 +1663,15 @@ dependencies = [
|
|||
"lazy_static",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook-registry"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.7"
|
||||
|
@ -1681,10 +1728,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.101"
|
||||
name = "subprocess"
|
||||
version = "0.2.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e90cde112c4b9690b8cbe810cba9ddd8bc1d7472e2cae317b69e9438c1cba7d2"
|
||||
checksum = "0c2e86926081dda636c546d8c5e641661049d7562a68f5488be4a1f7f66f6086"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.102"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fcd952facd492f9be3ef0d0b7032a6e442ee9b361d4acc2b1d0c4aaa5f613a1"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
@ -1741,13 +1798,24 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-launcher"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"clap 4.0.17",
|
||||
"ctrlc",
|
||||
"subprocess",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-router"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"bloom-inference-client",
|
||||
"clap 4.0.15",
|
||||
"clap 4.0.17",
|
||||
"futures",
|
||||
"parking_lot",
|
||||
"serde",
|
||||
|
@ -1872,6 +1940,7 @@ dependencies = [
|
|||
"num_cpus",
|
||||
"parking_lot",
|
||||
"pin-project-lite",
|
||||
"signal-hook-registry",
|
||||
"socket2",
|
||||
"tokio-macros",
|
||||
"winapi",
|
||||
|
@ -1910,9 +1979,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "tokio-stream"
|
||||
version = "0.1.10"
|
||||
version = "0.1.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f6edf2d6bc038a43d31353570e27270603f4648d18f5ed10c0e179abe43255af"
|
||||
checksum = "d660770404473ccd7bc9f8b28494a811bc18542b915c0855c51e8f419d5223ce"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"pin-project-lite",
|
||||
|
@ -2031,9 +2100,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "tower-layer"
|
||||
version = "0.3.1"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62"
|
||||
checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0"
|
||||
|
||||
[[package]]
|
||||
name = "tower-service"
|
||||
|
@ -2043,9 +2112,9 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52"
|
|||
|
||||
[[package]]
|
||||
name = "tracing"
|
||||
version = "0.1.36"
|
||||
version = "0.1.37"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2fce9567bd60a67d08a16488756721ba392f24f29006402881e43b19aac64307"
|
||||
checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"log",
|
||||
|
@ -2056,9 +2125,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "tracing-attributes"
|
||||
version = "0.1.22"
|
||||
version = "0.1.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "11c75893af559bc8e10716548bdef5cb2b983f8e637db9d0e15126b61b484ee2"
|
||||
checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
@ -2067,9 +2136,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "tracing-core"
|
||||
version = "0.1.29"
|
||||
version = "0.1.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5aeea4303076558a00714b823f9ad67d58a3bbda1df83d8827d21193156e22f7"
|
||||
checksum = "24eb03ba0eab1fd845050058ce5e616558e8f8d8fca633e6b163fe25c797213a"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"valuable",
|
||||
|
@ -2108,11 +2177,11 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "tracing-subscriber"
|
||||
version = "0.3.15"
|
||||
version = "0.3.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "60db860322da191b40952ad9affe65ea23e7dd6a5c442c2c42865810c6ab8e6b"
|
||||
checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70"
|
||||
dependencies = [
|
||||
"ansi_term",
|
||||
"nu-ansi-term",
|
||||
"sharded-slab",
|
||||
"smallvec",
|
||||
"thread_local",
|
||||
|
@ -2140,9 +2209,9 @@ checksum = "099b7128301d285f79ddd55b9a83d5e6b9e97c92e0ea0daebee7263e932de992"
|
|||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.4"
|
||||
version = "1.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dcc811dc4066ac62f84f11307873c4850cb653bfa9b1719cee2bd2204a4bc5dd"
|
||||
checksum = "6ceab39d59e4c9499d4e5a8ee0e2735b891bb7308ac83dfb4e80cad195c9f6f3"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-normalization"
|
||||
|
@ -2361,43 +2430,100 @@ version = "0.36.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2"
|
||||
dependencies = [
|
||||
"windows_aarch64_msvc",
|
||||
"windows_i686_gnu",
|
||||
"windows_i686_msvc",
|
||||
"windows_x86_64_gnu",
|
||||
"windows_x86_64_msvc",
|
||||
"windows_aarch64_msvc 0.36.1",
|
||||
"windows_i686_gnu 0.36.1",
|
||||
"windows_i686_msvc 0.36.1",
|
||||
"windows_x86_64_gnu 0.36.1",
|
||||
"windows_x86_64_msvc 0.36.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.42.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm",
|
||||
"windows_aarch64_msvc 0.42.0",
|
||||
"windows_i686_gnu 0.42.0",
|
||||
"windows_i686_msvc 0.42.0",
|
||||
"windows_x86_64_gnu 0.42.0",
|
||||
"windows_x86_64_gnullvm",
|
||||
"windows_x86_64_msvc 0.42.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.42.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41d2aa71f6f0cbe00ae5167d90ef3cfe66527d6f613ca78ac8024c3ccab9a19e"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.36.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.42.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dd0f252f5a35cac83d6311b2e795981f5ee6e67eb1f9a7f64eb4500fbc4dcdb4"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.36.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.42.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fbeae19f6716841636c28d695375df17562ca208b2b7d0dc47635a50ae6c5de7"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.36.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.42.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "84c12f65daa39dd2babe6e442988fc329d6243fdce47d7d2d155b8d874862246"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.36.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.42.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bf7b1b21b5362cbc318f686150e5bcea75ecedc74dd157d874d754a2ca44b0ed"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.42.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09d525d2ba30eeb3297665bd434a54297e4170c7f1a44cad4ef58095b4cd2028"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.36.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.42.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5"
|
||||
|
||||
[[package]]
|
||||
name = "winreg"
|
||||
version = "0.10.1"
|
|
@ -0,0 +1,11 @@
|
|||
[workspace]
|
||||
members = [
|
||||
"router",
|
||||
"router/client",
|
||||
"launcher"
|
||||
]
|
||||
|
||||
[profile.release]
|
||||
debug = 1
|
||||
incremental = true
|
||||
lto = "off"
|
36
Dockerfile
36
Dockerfile
|
@ -1,4 +1,4 @@
|
|||
FROM rust:1.64 as builder
|
||||
FROM rust:1.64 as router-builder
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
|
@ -9,7 +9,17 @@ WORKDIR /usr/src/router
|
|||
|
||||
RUN cargo install --path .
|
||||
|
||||
FROM nvidia/cuda:11.6.1-devel-ubuntu18.04
|
||||
FROM rust:1.64 as launcher-builder
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
COPY launcher launcher
|
||||
|
||||
WORKDIR /usr/src/launcher
|
||||
|
||||
RUN cargo install --path .
|
||||
|
||||
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
|
||||
|
||||
ENV LANG=C.UTF-8 \
|
||||
LC_ALL=C.UTF-8 \
|
||||
|
@ -34,17 +44,15 @@ RUN cd ~ && \
|
|||
bash ./Miniconda3-latest-Linux-x86_64.sh -bf -p /opt/miniconda && \
|
||||
conda create -n text-generation python=3.9 -y
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
COPY server/Makefile server/Makefile
|
||||
|
||||
# Install specific version of torch
|
||||
RUN /opt/miniconda/envs/text-generation/bin/pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir
|
||||
RUN cd server && make install-torch
|
||||
|
||||
# Install specific version of transformers
|
||||
RUN wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
|
||||
unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
|
||||
rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
|
||||
cd transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 && \
|
||||
/opt/miniconda/envs/text-generation/bin/python setup.py install
|
||||
|
||||
WORKDIR /usr/src
|
||||
RUN cd server && make install-transformers
|
||||
|
||||
# Install server
|
||||
COPY server server
|
||||
|
@ -52,9 +60,7 @@ RUN cd server && \
|
|||
/opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir
|
||||
|
||||
# Install router
|
||||
COPY --from=builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
|
||||
COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
|
||||
COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
|
||||
COPY run.sh .
|
||||
RUN chmod +x run.sh
|
||||
|
||||
CMD ["./run.sh"]
|
||||
CMD text-generation-launcher --model-name $MODEL_NAME --num-shard $NUM_GPUS --shard-directory $MODEL_BASE_PATH
|
|
@ -0,0 +1,19 @@
|
|||
install-server:
|
||||
cd server && make pip-install
|
||||
|
||||
install-router:
|
||||
cd router && cargo install --path .
|
||||
|
||||
install-launcher:
|
||||
cd launcher && cargo install --path .
|
||||
|
||||
install:
|
||||
make install-server
|
||||
make install-router
|
||||
make install-launcher
|
||||
|
||||
run-bloom-560m:
|
||||
text-generation-launcher --model-name bigscience/bloom-560m --shard-directory /tmp/models --num-shard 2
|
||||
|
||||
run-bloom:
|
||||
text-generation-launcher --model-name bigscience/bloom --shard-directory /tmp/models --num-shard 8
|
51
README.md
51
README.md
|
@ -1,50 +1,51 @@
|
|||
# Text Generation Inference
|
||||
# LLM Text Generation Inference
|
||||
|
||||
A Rust and gRPC server for text generation inference.
|
||||
<div align="center">
|
||||
|
||||
## Load Tests
|
||||
![architecture](assets/architecture.jpg)
|
||||
|
||||
</div>
|
||||
|
||||
A Rust and gRPC server for large language models text generation inference.
|
||||
|
||||
## Load Tests for BLOOM
|
||||
|
||||
See `k6/load_test.js`
|
||||
We send the default examples with a 1 second delay between each request.
|
||||
We send the default examples with a 1 second delay between requests.
|
||||
|
||||
Stages:
|
||||
- Ramp up to 50 concurrent requests per second in 1min
|
||||
- Ramp up from 50 to 100 concurrent requests per second in 2min
|
||||
- Ramp down to 0 concurrent requests per second in 1min
|
||||
- Ramp up to 50 vus in 1min
|
||||
- Ramp up from 50 to 100 vus in 2min
|
||||
- Ramp down to 0 vus in 1min
|
||||
|
||||
|
||||
| | avg | min | med | max | p(90) | p(95) | RPS |
|
||||
|------------------------|-----------|-----------|-----------|------------|-----------|-----------|----------|
|
||||
| Original code | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 |
|
||||
| ISO with original code | 8.88s | 959.53ms | 8.89s | 17.08s | 13.34s | 14.12s | 5.94 |
|
||||
| New batching logic | **5.44s** | **1.27s** | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |
|
||||
|--------------------------------------------------------------|-----------|--------------|-----------|------------|-----------|-----------|----------|
|
||||
| [Original code](https://github.com/huggingface/transformers_bloom_parallel) | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 |
|
||||
| ISO with original code | 8.88s | **959.53ms** | 8.89s | 17.08s | 13.34s | 14.12s | 5.94 |
|
||||
| New batching logic | **5.44s** | 1.27s | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |
|
||||
|
||||
## Install
|
||||
|
||||
```shell
|
||||
cd server
|
||||
pip install .
|
||||
```
|
||||
|
||||
```
|
||||
cd router
|
||||
cargo build --release
|
||||
make install
|
||||
```
|
||||
|
||||
## Run
|
||||
|
||||
```shell
|
||||
python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-directory /dev/shm/models
|
||||
make run-bloom-560m
|
||||
```
|
||||
|
||||
## Test
|
||||
|
||||
```shell
|
||||
./router/target/release/router
|
||||
curl 127.0.0.1:3000/generate \
|
||||
-X POST \
|
||||
-d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
|
||||
-H 'Content-Type: application/json'
|
||||
```
|
||||
|
||||
## TODO:
|
||||
|
||||
- [ ] Add docstrings + comments everywhere as the codebase is fairly complicated
|
||||
- [ ] Add tests
|
||||
- [ ] Add shutdown logic in router and server
|
||||
- [ ] Improve multi-processing logic in server
|
||||
- [ ] Improve past key layer indexing?
|
||||
- [ ] Add tests for the `server/model` logic
|
|
@ -8,7 +8,7 @@ environment_variables:
|
|||
MODEL_NAME: bigscience/bloom
|
||||
NUM_GPUS: 8
|
||||
environment:
|
||||
image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.3
|
||||
image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.2
|
||||
inference_config:
|
||||
liveness_route:
|
||||
port: 3000
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 132 KiB |
|
@ -0,0 +1,13 @@
|
|||
[package]
|
||||
name = "text-generation-launcher"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
authors = ["Olivier Dehaene"]
|
||||
description = "Text Generation Launcher"
|
||||
|
||||
[dependencies]
|
||||
clap = { version = "4.0.15", features = ["derive", "env"] }
|
||||
ctrlc = "3.2.3"
|
||||
subprocess = "0.2.9"
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = "0.3.16"
|
|
@ -0,0 +1,358 @@
|
|||
use clap::Parser;
|
||||
use std::io::{BufRead, BufReader, Read};
|
||||
use std::path::Path;
|
||||
use std::process::ExitCode;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::mpsc::TryRecvError;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{mpsc, Mutex};
|
||||
use std::thread;
|
||||
use std::thread::sleep;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{fs, io};
|
||||
use subprocess::{Popen, PopenConfig, PopenError, Redirection};
|
||||
|
||||
/// App Configuration
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
#[clap(default_value = "bigscience/bloom-560m", long, env)]
|
||||
model_name: String,
|
||||
#[clap(long, env)]
|
||||
num_shard: Option<usize>,
|
||||
#[clap(long, env)]
|
||||
shard_directory: Option<String>,
|
||||
#[clap(default_value = "128", long, env)]
|
||||
max_concurrent_requests: usize,
|
||||
#[clap(default_value = "1000", long, env)]
|
||||
max_input_length: usize,
|
||||
#[clap(default_value = "32", long, env)]
|
||||
max_batch_size: usize,
|
||||
#[clap(default_value = "5", long, env)]
|
||||
max_waiting_time: u64,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
port: u16,
|
||||
#[clap(default_value = "/tmp/text-generation-server", long, env)]
|
||||
shard_uds_path: String,
|
||||
#[clap(default_value = "localhost", long, env)]
|
||||
master_addr: String,
|
||||
#[clap(default_value = "29500", long, env)]
|
||||
master_port: usize,
|
||||
}
|
||||
|
||||
fn main() -> ExitCode {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
// Pattern match configuration
|
||||
let Args {
|
||||
model_name,
|
||||
num_shard,
|
||||
shard_directory,
|
||||
max_concurrent_requests,
|
||||
max_input_length,
|
||||
max_batch_size,
|
||||
max_waiting_time,
|
||||
port,
|
||||
shard_uds_path,
|
||||
master_addr,
|
||||
master_port,
|
||||
} = Args::parse();
|
||||
|
||||
// By default we only have one master shard
|
||||
let num_shard = num_shard.unwrap_or(1);
|
||||
|
||||
// Signal handler
|
||||
let running = Arc::new(AtomicBool::new(true));
|
||||
let r = running.clone();
|
||||
ctrlc::set_handler(move || {
|
||||
r.store(false, Ordering::SeqCst);
|
||||
})
|
||||
.expect("Error setting Ctrl-C handler");
|
||||
|
||||
// Shared shutdown bool
|
||||
let shutdown = Arc::new(Mutex::new(false));
|
||||
// Shared shutdown channel
|
||||
// When shutting down, the main thread will wait for all senders to be dropped
|
||||
let (shutdown_sender, shutdown_receiver) = mpsc::channel();
|
||||
|
||||
// Shared channel to track shard status
|
||||
let (status_sender, status_receiver) = mpsc::channel();
|
||||
|
||||
// Start shard processes
|
||||
for rank in 0..num_shard {
|
||||
let model_name = model_name.clone();
|
||||
let uds_path = shard_uds_path.clone();
|
||||
let shard_directory = shard_directory.clone();
|
||||
let master_addr = master_addr.clone();
|
||||
let status_sender = status_sender.clone();
|
||||
let shutdown = shutdown.clone();
|
||||
let shutdown_sender = shutdown_sender.clone();
|
||||
thread::spawn(move || {
|
||||
shard_manager(
|
||||
model_name,
|
||||
uds_path,
|
||||
shard_directory,
|
||||
rank,
|
||||
num_shard,
|
||||
master_addr,
|
||||
master_port,
|
||||
status_sender,
|
||||
shutdown,
|
||||
shutdown_sender,
|
||||
)
|
||||
});
|
||||
}
|
||||
drop(shutdown_sender);
|
||||
|
||||
// Wait for shard to start
|
||||
let mut shard_ready = 0;
|
||||
while running.load(Ordering::SeqCst) {
|
||||
match status_receiver.try_recv() {
|
||||
Ok(ShardStatus::Ready) => {
|
||||
shard_ready += 1;
|
||||
if shard_ready == num_shard {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(TryRecvError::Empty) => {
|
||||
sleep(Duration::from_millis(100));
|
||||
}
|
||||
Ok(ShardStatus::Failed((rank, err))) => {
|
||||
tracing::error!("Shard {} failed to start:\n{}", rank, err);
|
||||
shutdown_shards(shutdown, &shutdown_receiver);
|
||||
return ExitCode::FAILURE;
|
||||
}
|
||||
Err(TryRecvError::Disconnected) => {
|
||||
tracing::error!("Shard status channel disconnected");
|
||||
shutdown_shards(shutdown, &shutdown_receiver);
|
||||
return ExitCode::FAILURE;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We might have received a termination signal
|
||||
if !running.load(Ordering::SeqCst) {
|
||||
shutdown_shards(shutdown, &shutdown_receiver);
|
||||
return ExitCode::SUCCESS;
|
||||
}
|
||||
|
||||
// All shard started
|
||||
// Start webserver
|
||||
tracing::info!("Starting Webserver");
|
||||
let mut webserver = match Popen::create(
|
||||
&[
|
||||
"text-generation-router",
|
||||
"--max-concurrent-requests",
|
||||
&max_concurrent_requests.to_string(),
|
||||
"--max-input-length",
|
||||
&max_input_length.to_string(),
|
||||
"--max-batch-size",
|
||||
&max_batch_size.to_string(),
|
||||
"--max-waiting-time",
|
||||
&max_waiting_time.to_string(),
|
||||
"--port",
|
||||
&port.to_string(),
|
||||
"--master-shard-uds-path",
|
||||
&format!("{}-0", shard_uds_path),
|
||||
"--tokenizer-name",
|
||||
&model_name,
|
||||
],
|
||||
PopenConfig {
|
||||
stdout: Redirection::Pipe,
|
||||
stderr: Redirection::Pipe,
|
||||
// Needed for the shutdown procedure
|
||||
setpgid: true,
|
||||
..Default::default()
|
||||
},
|
||||
) {
|
||||
Ok(p) => p,
|
||||
Err(err) => {
|
||||
tracing::error!("Failed to start webserver: {}", err);
|
||||
if let PopenError::IoError(err) = err {
|
||||
if err.kind() == io::ErrorKind::NotFound {
|
||||
tracing::error!("text-generation-router not found in PATH");
|
||||
tracing::error!("Please install it with `make install-router`")
|
||||
}
|
||||
}
|
||||
|
||||
shutdown_shards(shutdown, &shutdown_receiver);
|
||||
return ExitCode::FAILURE;
|
||||
}
|
||||
};
|
||||
|
||||
// Redirect STDOUT and STDERR to the console
|
||||
let webserver_stdout = webserver.stdout.take().unwrap();
|
||||
let webserver_stderr = webserver.stderr.take().unwrap();
|
||||
|
||||
thread::spawn(move || {
|
||||
let stdout = BufReader::new(webserver_stdout);
|
||||
let stderr = BufReader::new(webserver_stderr);
|
||||
for line in stdout.lines() {
|
||||
println!("{}", line.unwrap());
|
||||
}
|
||||
for line in stderr.lines() {
|
||||
println!("{}", line.unwrap());
|
||||
}
|
||||
});
|
||||
|
||||
// Default exit code
|
||||
let mut exit_code = ExitCode::SUCCESS;
|
||||
|
||||
while running.load(Ordering::SeqCst) {
|
||||
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
|
||||
tracing::error!("Shard {} failed:\n{}", rank, err);
|
||||
exit_code = ExitCode::FAILURE;
|
||||
break;
|
||||
};
|
||||
|
||||
match webserver.poll() {
|
||||
Some(_) => {
|
||||
tracing::error!("Webserver Crashed");
|
||||
shutdown_shards(shutdown, &shutdown_receiver);
|
||||
return ExitCode::FAILURE;
|
||||
}
|
||||
None => {
|
||||
sleep(Duration::from_millis(100));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Graceful termination
|
||||
webserver.terminate().unwrap();
|
||||
tracing::info!("Waiting for webserver to gracefully shutdown");
|
||||
webserver.wait_timeout(Duration::from_secs(90)).unwrap();
|
||||
tracing::info!("Webserver terminated");
|
||||
shutdown_shards(shutdown, &shutdown_receiver);
|
||||
|
||||
exit_code
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ShardStatus {
|
||||
Ready,
|
||||
Failed((usize, String)),
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn shard_manager(
|
||||
model_name: String,
|
||||
uds_path: String,
|
||||
shard_directory: Option<String>,
|
||||
rank: usize,
|
||||
world_size: usize,
|
||||
master_addr: String,
|
||||
master_port: usize,
|
||||
status_sender: mpsc::Sender<ShardStatus>,
|
||||
shutdown: Arc<Mutex<bool>>,
|
||||
_shutdown_sender: mpsc::Sender<()>,
|
||||
) {
|
||||
// Get UDS path
|
||||
let uds_string = format!("{}-{}", uds_path, rank);
|
||||
let uds = Path::new(&uds_string);
|
||||
// Clean previous runs
|
||||
fs::remove_file(uds).unwrap_or_default();
|
||||
|
||||
// Process args
|
||||
let mut shard_argv = vec![
|
||||
"bloom-inference-server".to_string(),
|
||||
"serve".to_string(),
|
||||
model_name,
|
||||
"--uds-path".to_string(),
|
||||
uds_path,
|
||||
];
|
||||
|
||||
if world_size > 1 {
|
||||
shard_argv.push("--sharded".to_string());
|
||||
}
|
||||
|
||||
if let Some(shard_directory) = shard_directory {
|
||||
shard_argv.push("--shard-directory".to_string());
|
||||
shard_argv.push(shard_directory);
|
||||
}
|
||||
|
||||
// Start process
|
||||
tracing::info!("Starting shard {}", rank);
|
||||
let mut p = match Popen::create(
|
||||
&shard_argv,
|
||||
PopenConfig {
|
||||
stdout: Redirection::Pipe,
|
||||
stderr: Redirection::Pipe,
|
||||
// Needed for the shutdown procedure
|
||||
setpgid: true,
|
||||
// NCCL env vars
|
||||
env: Some(vec![
|
||||
("RANK".parse().unwrap(), rank.to_string().parse().unwrap()),
|
||||
(
|
||||
"WORLD_SIZE".parse().unwrap(),
|
||||
world_size.to_string().parse().unwrap(),
|
||||
),
|
||||
("MASTER_ADDR".parse().unwrap(), master_addr.parse().unwrap()),
|
||||
(
|
||||
"MASTER_PORT".parse().unwrap(),
|
||||
master_port.to_string().parse().unwrap(),
|
||||
),
|
||||
]),
|
||||
..Default::default()
|
||||
},
|
||||
) {
|
||||
Ok(p) => p,
|
||||
Err(err) => {
|
||||
if let PopenError::IoError(ref err) = err {
|
||||
if err.kind() == io::ErrorKind::NotFound {
|
||||
tracing::error!("bloom-inference-server not found in PATH");
|
||||
tracing::error!("Please install it with `make install-server`")
|
||||
}
|
||||
}
|
||||
status_sender
|
||||
.send(ShardStatus::Failed((rank, err.to_string())))
|
||||
.unwrap();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut ready = false;
|
||||
let start_time = Instant::now();
|
||||
loop {
|
||||
// Process exited
|
||||
if p.poll().is_some() {
|
||||
let mut err = String::new();
|
||||
p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
|
||||
status_sender
|
||||
.send(ShardStatus::Failed((rank, err)))
|
||||
.unwrap();
|
||||
return;
|
||||
}
|
||||
|
||||
// We received a shutdown signal
|
||||
if *shutdown.lock().unwrap() {
|
||||
p.terminate().unwrap();
|
||||
let _ = p.wait_timeout(Duration::from_secs(90));
|
||||
tracing::info!("Shard {} terminated", rank);
|
||||
return;
|
||||
}
|
||||
|
||||
// Shard is ready
|
||||
if uds.exists() && !ready {
|
||||
tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed());
|
||||
status_sender.send(ShardStatus::Ready).unwrap();
|
||||
ready = true;
|
||||
} else if !ready {
|
||||
tracing::info!("Waiting for shard {} to be ready...", rank);
|
||||
}
|
||||
sleep(Duration::from_secs(5));
|
||||
}
|
||||
}
|
||||
|
||||
fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receiver<()>) {
|
||||
tracing::info!("Shutting down shards");
|
||||
// Update shutdown value to true
|
||||
// This will be picked up by the shard manager
|
||||
{
|
||||
let mut shutdown = shutdown.lock().unwrap();
|
||||
*shutdown = true;
|
||||
}
|
||||
|
||||
// Wait for shards to shutdown
|
||||
// This will block till all shutdown_sender are dropped
|
||||
let _ = shutdown_receiver.recv();
|
||||
}
|
|
@ -11,10 +11,6 @@ service TextGenerationService {
|
|||
rpc Generate (GenerateRequest) returns (GenerateResponse);
|
||||
/// Generate tokens for a list of cached batches
|
||||
rpc GenerateWithCache (GenerateWithCacheRequest) returns (GenerateWithCacheResponse);
|
||||
/// Generate tokens until the text of at least one request of the batch is generated
|
||||
rpc GenerateUntilFinished (GenerateUntilFinishedRequest) returns (GenerateUntilFinishedResponse);
|
||||
/// Generate tokens until the text of at least one request of the cached batches i finished
|
||||
rpc GenerateUntilFinishedWithCache (GenerateUntilFinishedWithCacheRequest) returns (GenerateUntilFinishedWithCacheResponse);
|
||||
}
|
||||
|
||||
/// Empty request
|
||||
|
@ -92,27 +88,3 @@ message GenerateWithCacheResponse {
|
|||
/// Next batch (cached)
|
||||
optional Batch batch = 2;
|
||||
}
|
||||
|
||||
message GenerateUntilFinishedRequest {
|
||||
/// Batch
|
||||
Batch batch = 1;
|
||||
}
|
||||
|
||||
message GenerateUntilFinishedResponse {
|
||||
/// Finished requests
|
||||
repeated GeneratedText generated_texts = 1;
|
||||
/// Next batch (cached)
|
||||
optional Batch batch = 2;
|
||||
}
|
||||
|
||||
message GenerateUntilFinishedWithCacheRequest {
|
||||
/// Cached batches
|
||||
repeated Batch batches = 1;
|
||||
}
|
||||
|
||||
message GenerateUntilFinishedWithCacheResponse {
|
||||
/// Finished requests
|
||||
repeated GeneratedText generated_texts = 1;
|
||||
/// Next batch (cached)
|
||||
optional Batch batch = 2;
|
||||
}
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
/target
|
|
@ -22,16 +22,7 @@ serde = "1.0.145"
|
|||
serde_json = "1.0.85"
|
||||
thiserror = "1.0.37"
|
||||
tokenizers = "0.13.0"
|
||||
tokio = { version = "1.21.1", features = ["rt-multi-thread", "parking_lot", "sync"] }
|
||||
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tracing = "0.1.36"
|
||||
tracing-subscriber = "0.3.15"
|
||||
|
||||
[workspace]
|
||||
members = [
|
||||
"client",
|
||||
]
|
||||
|
||||
[profile.release]
|
||||
debug = 1
|
||||
incremental = true
|
||||
lto = "off"
|
||||
|
|
|
@ -5,8 +5,6 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
futures = "^0.3"
|
||||
#grpc-error-details = { path = "../../grpc-error-details" }
|
||||
#grpc-metadata = { path = "../../grpc-metadata" }
|
||||
prost = "^0.9"
|
||||
thiserror = "^1.0"
|
||||
tokio = { version = "^1.21", features = ["sync"] }
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
/// Single shard Client
|
||||
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
|
||||
use crate::pb::generate::v1::*;
|
||||
use crate::Result;
|
||||
use tonic::transport::{Channel, Uri};
|
||||
use tracing::*;
|
||||
|
||||
/// BLOOM Inference gRPC client
|
||||
/// Text Generation Inference gRPC client
|
||||
#[derive(Clone)]
|
||||
pub struct Client {
|
||||
stub: TextGenerationServiceClient<Channel>,
|
||||
|
@ -34,6 +35,7 @@ impl Client {
|
|||
})
|
||||
}
|
||||
|
||||
/// Returns a list of uris or unix sockets of all shards
|
||||
#[instrument(skip(self))]
|
||||
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
|
||||
let request = tonic::Request::new(ServiceDiscoveryRequest {});
|
||||
|
@ -46,6 +48,7 @@ impl Client {
|
|||
.into_inner()
|
||||
.urls
|
||||
.into_iter()
|
||||
// Remove unix socket prefix
|
||||
.map(|url| match url.strip_prefix("unix://") {
|
||||
None => url,
|
||||
Some(stripped_url) => stripped_url.to_string(),
|
||||
|
@ -54,6 +57,7 @@ impl Client {
|
|||
Ok(urls)
|
||||
}
|
||||
|
||||
/// Clear the past generations cache
|
||||
#[instrument(skip(self))]
|
||||
pub async fn clear_cache(&mut self) -> Result<()> {
|
||||
let request = tonic::Request::new(ClearCacheRequest {});
|
||||
|
@ -64,6 +68,10 @@ impl Client {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given batch
|
||||
///
|
||||
/// Returns a list of generated texts of request that met their stopping criteria
|
||||
/// and the next cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
||||
let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
|
||||
|
@ -76,6 +84,10 @@ impl Client {
|
|||
Ok((response.generated_texts, response.batch))
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given cached batch
|
||||
///
|
||||
/// Returns a list of generated texts of request that met their stopping criteria
|
||||
/// and the next cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn generate_with_cache(
|
||||
&mut self,
|
||||
|
@ -90,34 +102,4 @@ impl Client {
|
|||
.into_inner();
|
||||
Ok((response.generated_texts, response.batch))
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub async fn generate_until_finished(
|
||||
&mut self,
|
||||
batch: Batch,
|
||||
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
||||
let request = tonic::Request::new(GenerateUntilFinishedRequest { batch: Some(batch) });
|
||||
let response = self
|
||||
.stub
|
||||
.generate_until_finished(request)
|
||||
.instrument(info_span!("generate_until_finished"))
|
||||
.await?
|
||||
.into_inner();
|
||||
Ok((response.generated_texts, response.batch))
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub async fn generate_until_finished_with_cache(
|
||||
&mut self,
|
||||
batches: Vec<Batch>,
|
||||
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
||||
let request = tonic::Request::new(GenerateUntilFinishedWithCacheRequest { batches });
|
||||
let response = self
|
||||
.stub
|
||||
.generate_until_finished_with_cache(request)
|
||||
.instrument(info_span!("generate_until_finished_with_cache"))
|
||||
.await?
|
||||
.into_inner();
|
||||
Ok((response.generated_texts, response.batch))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
//! BLOOM Inference gRPC client library
|
||||
//! Text Generation gRPC client library
|
||||
|
||||
mod client;
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
mod pb;
|
||||
mod sharded_client;
|
||||
|
||||
|
@ -8,7 +9,7 @@ pub use client::Client;
|
|||
pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request};
|
||||
pub use sharded_client::ShardedClient;
|
||||
use thiserror::Error;
|
||||
pub use tonic::transport;
|
||||
use tonic::transport;
|
||||
use tonic::Status;
|
||||
|
||||
#[derive(Error, Debug, Clone)]
|
||||
|
@ -21,7 +22,7 @@ pub enum ClientError {
|
|||
|
||||
impl From<Status> for ClientError {
|
||||
fn from(err: Status) -> Self {
|
||||
Self::Generation(err.to_string())
|
||||
Self::Generation(err.message().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
/// Multi shard Client
|
||||
use crate::Result;
|
||||
use crate::{Batch, Client, GeneratedText};
|
||||
use futures::future::join_all;
|
||||
use tokio::sync::{broadcast, mpsc};
|
||||
use tonic::transport::Uri;
|
||||
|
||||
/// List of all available commands that can be sent through the command channel
|
||||
#[derive(Clone, Debug)]
|
||||
enum Command {
|
||||
Generate(
|
||||
|
@ -14,36 +16,32 @@ enum Command {
|
|||
Vec<Batch>,
|
||||
mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
|
||||
),
|
||||
GenerateUntilFinished(
|
||||
Batch,
|
||||
mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
|
||||
),
|
||||
GenerateUntilFinishedWithCache(
|
||||
Vec<Batch>,
|
||||
mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
|
||||
),
|
||||
ClearCache(mpsc::Sender<Result<()>>),
|
||||
}
|
||||
|
||||
/// Tokio task that handles the communication with a single shard
|
||||
///
|
||||
/// We subscribe on a broadcast channel to receive commands that will be sent by
|
||||
/// the ShardedClient.
|
||||
///
|
||||
/// Each command is fan out to all shards.
|
||||
///
|
||||
/// The result of the command is sent back to the ShardedClient through a mpsc channel (multi
|
||||
/// producer = the shards, single consumer = the ShardedClient).
|
||||
async fn client_task(mut client: Client, mut request_subscriber: broadcast::Receiver<Command>) {
|
||||
while let Ok(message) = request_subscriber.recv().await {
|
||||
match message {
|
||||
Command::Generate(batch, response_tx) => {
|
||||
let result = client.generate(batch).await;
|
||||
// We can unwrap_or(()) here because the only error that can happen is if the
|
||||
// receiver is dropped, which means that the ShardedClient already received a
|
||||
// response from another shard
|
||||
response_tx.try_send(result).unwrap_or(());
|
||||
}
|
||||
Command::GenerateWithCache(batches, response_tx) => {
|
||||
let result = client.generate_with_cache(batches).await;
|
||||
response_tx.try_send(result).unwrap_or(());
|
||||
}
|
||||
Command::GenerateUntilFinished(batch, response_tx) => {
|
||||
let result = client.generate_until_finished(batch).await;
|
||||
response_tx.try_send(result).unwrap_or(());
|
||||
}
|
||||
Command::GenerateUntilFinishedWithCache(batches, response_tx) => {
|
||||
let result = client.generate_until_finished_with_cache(batches).await;
|
||||
response_tx.try_send(result).unwrap_or(());
|
||||
}
|
||||
Command::ClearCache(response_tx) => {
|
||||
let result = client.clear_cache().await;
|
||||
response_tx.try_send(result).unwrap_or(());
|
||||
|
@ -52,30 +50,42 @@ async fn client_task(mut client: Client, mut request_subscriber: broadcast::Rece
|
|||
}
|
||||
}
|
||||
|
||||
/// Text Generation Inference gRPC multi client
|
||||
pub struct ShardedClient {
|
||||
_clients: Vec<Client>,
|
||||
request_tx: broadcast::Sender<Command>,
|
||||
}
|
||||
|
||||
impl ShardedClient {
|
||||
fn new(mut clients: Vec<Client>) -> Self {
|
||||
fn new(clients: Vec<Client>) -> Self {
|
||||
// The broadcast channel to communicate with the shards
|
||||
// We use a capacity of one as the shards are not asynchronous and can only process one
|
||||
// command at a time
|
||||
let (request_tx, _) = broadcast::channel(1);
|
||||
|
||||
for client in clients.drain(..) {
|
||||
// Spawn client tasks
|
||||
for client in clients.iter() {
|
||||
let request_subscriber = request_tx.subscribe();
|
||||
tokio::spawn(client_task(client, request_subscriber));
|
||||
tokio::spawn(client_task(client.clone(), request_subscriber));
|
||||
}
|
||||
|
||||
Self { request_tx }
|
||||
Self {
|
||||
_clients: clients,
|
||||
request_tx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new ShardedClient from a master client. The master client will communicate with
|
||||
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
|
||||
async fn from_master_client(mut master_client: Client) -> Result<Self> {
|
||||
// Get all uris/unix sockets from the master client
|
||||
let uris = master_client.service_discovery().await.unwrap();
|
||||
let futures = uris.into_iter().map(|path| Client::connect_uds(path));
|
||||
let futures = uris.into_iter().map(Client::connect_uds);
|
||||
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
|
||||
Ok(Self::new(clients?))
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given url
|
||||
/// Returns a client connected to the given uri
|
||||
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||
let master_client = Client::connect(uri).await?;
|
||||
Self::from_master_client(master_client).await
|
||||
|
@ -87,51 +97,43 @@ impl ShardedClient {
|
|||
Self::from_master_client(master_client).await
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given batch
|
||||
///
|
||||
/// Returns a list of generated texts of request that met their stopping criteria
|
||||
/// and the next cached batch
|
||||
pub async fn generate(&self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
||||
// Create a channel to receive the response from the shards
|
||||
// We will only ever receive one message on this channel
|
||||
let (response_tx, mut response_rx) = mpsc::channel(1);
|
||||
self.request_tx
|
||||
.send(Command::Generate(batch, response_tx))
|
||||
.unwrap();
|
||||
// As soon as we receive one response, we can return as all shards will return the same
|
||||
response_rx.recv().await.unwrap()
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given cached batch
|
||||
///
|
||||
/// Returns a list of generated texts of request that met their stopping criteria
|
||||
/// and the next cached batch
|
||||
pub async fn generate_with_cache(
|
||||
&self,
|
||||
batches: Vec<Batch>,
|
||||
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
||||
// Create a channel to receive the response from the shards
|
||||
// We will only ever receive one message on this channel
|
||||
let (response_tx, mut response_rx) = mpsc::channel(1);
|
||||
self.request_tx
|
||||
.send(Command::GenerateWithCache(batches, response_tx))
|
||||
.unwrap();
|
||||
// As soon as we receive one response, we can return as all shards will return the same
|
||||
response_rx.recv().await.unwrap()
|
||||
}
|
||||
|
||||
pub async fn generate_until_finished(
|
||||
&self,
|
||||
batch: Batch,
|
||||
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
||||
let (response_tx, mut response_rx) = mpsc::channel(1);
|
||||
self.request_tx
|
||||
.send(Command::GenerateUntilFinished(batch, response_tx))
|
||||
.unwrap();
|
||||
response_rx.recv().await.unwrap()
|
||||
}
|
||||
|
||||
pub async fn generate_until_finished_with_cache(
|
||||
&self,
|
||||
batches: Vec<Batch>,
|
||||
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
||||
let (response_tx, mut response_rx) = mpsc::channel(1);
|
||||
self.request_tx
|
||||
.send(Command::GenerateUntilFinishedWithCache(
|
||||
batches,
|
||||
response_tx,
|
||||
))
|
||||
.unwrap();
|
||||
response_rx.recv().await.unwrap()
|
||||
}
|
||||
|
||||
/// Clear the past generations cache
|
||||
pub async fn clear_cache(&self) -> Result<()> {
|
||||
// Create a channel to receive the response from the shards
|
||||
// We will only ever receive one message on this channel
|
||||
let (response_tx, mut response_rx) = mpsc::channel(1);
|
||||
self.request_tx
|
||||
.send(Command::ClearCache(response_tx))
|
||||
|
|
|
@ -1,129 +1,158 @@
|
|||
use crate::server::GenerateRequest;
|
||||
/// Batching and inference logic
|
||||
use crate::GenerateRequest;
|
||||
use crate::{Db, Entry};
|
||||
use axum::http::StatusCode;
|
||||
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{oneshot, Notify};
|
||||
use tokio::time::Instant;
|
||||
use tracing::instrument;
|
||||
|
||||
const MAX_LENGTH: usize = 128;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum InferError {
|
||||
#[error("Request failed during generation: {0}")]
|
||||
GenerationError(String),
|
||||
#[error("Model is overloaded")]
|
||||
Overloaded,
|
||||
}
|
||||
|
||||
impl From<InferError> for (StatusCode, String) {
|
||||
fn from(err: InferError) -> Self {
|
||||
match err {
|
||||
InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
|
||||
InferError::Overloaded => (StatusCode::TOO_MANY_REQUESTS, err.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Batcher
|
||||
#[derive(Clone)]
|
||||
pub struct Batcher {
|
||||
/// Request database
|
||||
db: Db,
|
||||
/// Shared state
|
||||
shared: Arc<Shared>,
|
||||
}
|
||||
|
||||
/// Batcher shared state
|
||||
struct Shared {
|
||||
/// Batching background Tokio task notifier
|
||||
batching_task: Notify,
|
||||
}
|
||||
|
||||
impl Batcher {
|
||||
pub(crate) fn new(client: ShardedClient, max_batch_size: usize) -> Self {
|
||||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
max_batch_size: usize,
|
||||
max_waiting_time: Duration,
|
||||
) -> Self {
|
||||
// Batcher shared state
|
||||
let db = Db::new();
|
||||
let shared = Arc::new(Shared {
|
||||
batching_task: Notify::new(),
|
||||
});
|
||||
|
||||
tokio::spawn(batching_task(max_batch_size, client, db.clone(), shared.clone()));
|
||||
// Spawn batching background task that contains all the inference logic
|
||||
tokio::spawn(batching_task(
|
||||
max_batch_size,
|
||||
max_waiting_time,
|
||||
client,
|
||||
db.clone(),
|
||||
shared.clone(),
|
||||
));
|
||||
|
||||
Self { db, shared }
|
||||
}
|
||||
|
||||
/// Add a new request to the database and return a future that will generate the text
|
||||
pub(crate) async fn infer(
|
||||
&self,
|
||||
input_length: usize,
|
||||
request: GenerateRequest,
|
||||
) -> Result<String, InferError> {
|
||||
if self.db.len() > MAX_LENGTH {
|
||||
return Err(InferError::Overloaded);
|
||||
}
|
||||
let (request_tx, request_rx) = oneshot::channel();
|
||||
// One shot channel to communicate with the background batching task
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
|
||||
// Try to append the request to the database
|
||||
self.db.append(Entry {
|
||||
request,
|
||||
response_tx: request_tx,
|
||||
response_tx,
|
||||
input_length,
|
||||
time: Instant::now(),
|
||||
});
|
||||
|
||||
// Notify the background task that we have a new entry in the database that needs
|
||||
// to be batched
|
||||
self.shared.batching_task.notify_waiters();
|
||||
match request_rx.await.unwrap() {
|
||||
|
||||
// Await on the response from the background task
|
||||
// We can safely unwrap as the background task will never drop the sender
|
||||
match response_rx.await.unwrap() {
|
||||
Ok(output) => Ok(output),
|
||||
Err(err) => Err(InferError::GenerationError(err.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn batching_task(max_batch_size: usize,
|
||||
/// Batching logic
|
||||
/// Will be launched in a background Tokio task
|
||||
///
|
||||
/// Batches requests and sends them to the inference server
|
||||
#[instrument(skip(client, db, shared))]
|
||||
async fn batching_task(
|
||||
max_batch_size: usize,
|
||||
max_waiting_time: Duration,
|
||||
client: ShardedClient,
|
||||
db: Db,
|
||||
shared: Arc<Shared>) {
|
||||
shared: Arc<Shared>,
|
||||
) {
|
||||
// Minimum batch size after which we try to add more requests
|
||||
let limit_min_batch_size = (max_batch_size / 2) as u32;
|
||||
|
||||
// Infinite loop
|
||||
loop {
|
||||
// Wait for a notification from the Batcher struct
|
||||
shared.batching_task.notified().await;
|
||||
|
||||
if let Some(batch) = db.next_batch(max_batch_size) {
|
||||
let request_ids = batch.requests.iter().map(|req| req.id).collect();
|
||||
let mut cached_batch = match batch.size {
|
||||
size if size > limit_min_batch_size => {
|
||||
wrap_future(client.generate_until_finished(batch), request_ids, &db).await
|
||||
}
|
||||
_ => wrap_future(client.generate(batch), request_ids, &db).await,
|
||||
};
|
||||
// Get the next batch from the DB
|
||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||
// waiting in the DB
|
||||
if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size, None) {
|
||||
let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
|
||||
|
||||
// We loop until we do not receive any cached batch from the inference server (== until
|
||||
// all requests have met their stopping criteria)
|
||||
while let Some(batch) = cached_batch {
|
||||
let mut current_batch_size = batch.size;
|
||||
// Get current batch info
|
||||
let batch_size = batch.size;
|
||||
let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
|
||||
let mut batches = vec![batch];
|
||||
|
||||
if current_batch_size <= limit_min_batch_size {
|
||||
if let Some(new_batch) = db.next_batch_minimum_size(limit_min_batch_size as usize, max_batch_size) {
|
||||
let new_batch_request_ids =
|
||||
new_batch.requests.iter().map(|req| req.id).collect();
|
||||
// If the current batch is too small, we try to add more requests to it
|
||||
if batch_size <= limit_min_batch_size {
|
||||
// Get the next batch from the DB that meet our minimum size criteria
|
||||
if let Some((new_request_ids, new_batch)) =
|
||||
db.next_batch(Some(limit_min_batch_size as usize), max_batch_size, None)
|
||||
{
|
||||
// Generate one token for this new batch to have the attention past in cache
|
||||
let new_cached_batch =
|
||||
wrap_future(client.generate(new_batch), new_batch_request_ids, &db)
|
||||
.await;
|
||||
wrap_future(client.generate(new_batch), new_request_ids, &db).await;
|
||||
// Extend current batch with the new batch
|
||||
if let Some(new_cached_batch) = new_cached_batch {
|
||||
request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
|
||||
batches.push(new_cached_batch);
|
||||
}
|
||||
}
|
||||
// If we don't have enough requests to meet the minimum size criteria, we
|
||||
// try to get the next batch from the DB that have been waiting over
|
||||
// the max_waiting_time
|
||||
else if let Some((new_request_ids, new_batch)) =
|
||||
db.next_batch(None, max_batch_size, Some(max_waiting_time))
|
||||
{
|
||||
let new_cached_batch =
|
||||
wrap_future(client.generate(new_batch), new_request_ids, &db).await;
|
||||
// Extend current batch with the new batch
|
||||
if let Some(new_cached_batch) = new_cached_batch {
|
||||
current_batch_size += new_cached_batch.size;
|
||||
request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
|
||||
batches.push(new_cached_batch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cached_batch = match current_batch_size {
|
||||
size if size > limit_min_batch_size => {
|
||||
wrap_future(
|
||||
client.generate_until_finished_with_cache(batches),
|
||||
request_ids,
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
}
|
||||
_ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await,
|
||||
};
|
||||
cached_batch =
|
||||
wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
|
||||
async fn wrap_future(
|
||||
future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
|
||||
request_ids: Vec<u64>,
|
||||
|
@ -134,6 +163,7 @@ async fn wrap_future(
|
|||
send_generated(generated_texts, db);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
send_error(err, request_ids, db);
|
||||
None
|
||||
|
@ -141,16 +171,20 @@ async fn wrap_future(
|
|||
}
|
||||
}
|
||||
|
||||
/// Send errors to the Batcher for all `request_ids`
|
||||
fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
|
||||
request_ids.into_iter().for_each(|id| {
|
||||
// We can `expect` here as the request id should always be in the DB
|
||||
let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
entry.response_tx.send(Err(error.clone())).unwrap_or(());
|
||||
});
|
||||
}
|
||||
|
||||
/// Send `generated_text` to the Batcher for all `finished`
|
||||
fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
|
||||
finished.into_iter().for_each(|output| {
|
||||
// We can `expect` here as the request id should always be in the DB
|
||||
let entry = db
|
||||
.remove(&output.request.unwrap().id)
|
||||
.expect("ID not found in db. This is a bug.");
|
||||
|
@ -158,3 +192,18 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
|
|||
entry.response_tx.send(Ok(output.output)).unwrap_or(());
|
||||
});
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum InferError {
|
||||
#[error("Request failed during generation: {0}")]
|
||||
GenerationError(String),
|
||||
}
|
||||
|
||||
/// Convert to Axum supported format
|
||||
impl From<InferError> for (StatusCode, String) {
|
||||
fn from(err: InferError) -> Self {
|
||||
match err {
|
||||
InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
287
router/src/db.rs
287
router/src/db.rs
|
@ -1,16 +1,173 @@
|
|||
/// This code is massively inspired by Tokio mini-redis
|
||||
use crate::server::{GenerateParameters, GenerateRequest};
|
||||
use crate::{GenerateParameters, GenerateRequest};
|
||||
use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
|
||||
use parking_lot::RwLock;
|
||||
use parking_lot::Mutex;
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::oneshot::Sender;
|
||||
use tokio::time::Instant;
|
||||
|
||||
/// Database entry
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Entry {
|
||||
/// Request
|
||||
pub request: GenerateRequest,
|
||||
/// Response sender to communicate between the Batcher and the batching_task
|
||||
pub response_tx: Sender<Result<String, ClientError>>,
|
||||
/// Number of tokens in the input
|
||||
pub input_length: usize,
|
||||
/// Instant when this entry was created
|
||||
pub time: Instant,
|
||||
}
|
||||
|
||||
/// Request Database
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct Db {
|
||||
pub shared: Arc<Shared>,
|
||||
}
|
||||
|
||||
/// Shared state
|
||||
#[derive(Debug)]
|
||||
pub struct Shared {
|
||||
state: Mutex<State>,
|
||||
}
|
||||
|
||||
/// Database State
|
||||
#[derive(Debug)]
|
||||
struct State {
|
||||
/// Database entries organized in a BTreeMap to be able to iterate over them in order
|
||||
entries: BTreeMap<u64, Entry>,
|
||||
|
||||
/// Id of the next entry
|
||||
next_id: u64,
|
||||
|
||||
/// Id of the next batch
|
||||
next_batch_id: u64,
|
||||
|
||||
/// Start ID of the next batch. Used to iterate inside the entries BTreeMap
|
||||
next_batch_start_id: u64,
|
||||
}
|
||||
|
||||
impl State {
|
||||
/// Get the next requests
|
||||
fn next_requests(
|
||||
&self,
|
||||
max_size: usize,
|
||||
min_waiting_time: Option<Duration>,
|
||||
) -> Option<(Vec<u64>, Vec<Request>)> {
|
||||
// Iterates for max_size over the BTreemap starting from next_batch_start_id
|
||||
let mut requests = Vec::new();
|
||||
let mut ids = Vec::new();
|
||||
|
||||
for (id, entry) in self
|
||||
.entries
|
||||
// Start from next_batch_start_id
|
||||
.range(self.next_batch_start_id..)
|
||||
// Take max_size
|
||||
.take(max_size)
|
||||
{
|
||||
if let Some(min_waiting_time) = min_waiting_time {
|
||||
// Only take entries that waited for at least min_waiting_time
|
||||
if entry.time.elapsed() < min_waiting_time {
|
||||
// Since entries are ordered, we already know that all following entries won't
|
||||
// satisfy the condition
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
requests.push(Request {
|
||||
id: *id,
|
||||
inputs: entry.request.inputs.clone(),
|
||||
input_length: entry.input_length as u32,
|
||||
parameters: Some(LogitsWarperParameters::from(
|
||||
entry.request.parameters.clone(),
|
||||
)),
|
||||
max_new_tokens: entry.request.parameters.max_new_tokens,
|
||||
});
|
||||
|
||||
ids.push(*id);
|
||||
}
|
||||
|
||||
if requests.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some((ids, requests))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Db {
|
||||
pub(crate) fn new() -> Self {
|
||||
// Shared state
|
||||
let shared = Arc::new(Shared {
|
||||
state: Mutex::new(State {
|
||||
entries: BTreeMap::new(),
|
||||
next_id: 0,
|
||||
next_batch_id: 0,
|
||||
next_batch_start_id: 0,
|
||||
}),
|
||||
});
|
||||
|
||||
Self { shared }
|
||||
}
|
||||
|
||||
/// Append an entry to the database
|
||||
pub(crate) fn append(&self, entry: Entry) {
|
||||
// Acquire lock
|
||||
let mut state = self.shared.state.lock();
|
||||
|
||||
// Insert entry
|
||||
let id = state.next_id;
|
||||
state.next_id += 1;
|
||||
state.entries.insert(id, entry);
|
||||
}
|
||||
|
||||
/// Remove an entry from the database if it exists
|
||||
pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
|
||||
let mut state = self.shared.state.lock();
|
||||
state.entries.remove(id)
|
||||
}
|
||||
|
||||
// Get the next batch
|
||||
pub(crate) fn next_batch(
|
||||
&self,
|
||||
min_size: Option<usize>,
|
||||
max_size: usize,
|
||||
min_waiting_time: Option<Duration>,
|
||||
) -> Option<(Vec<u64>, Batch)> {
|
||||
// Acquire lock
|
||||
let mut state = self.shared.state.lock();
|
||||
|
||||
// Get requests from the database
|
||||
if let Some((ids, requests)) = state.next_requests(max_size, min_waiting_time) {
|
||||
if let Some(min_size) = min_size {
|
||||
// If min_size is set, only return a batch if there are enough requests
|
||||
if requests.len() < min_size {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// Batch size
|
||||
let size = requests.len();
|
||||
// Longest input length for all requests in batch size
|
||||
// Used for padding inside the inference server
|
||||
let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap();
|
||||
let batch = Batch {
|
||||
id: state.next_batch_id,
|
||||
requests,
|
||||
size: size as u32,
|
||||
max_sequence_length,
|
||||
};
|
||||
// Update next_batch_start_id to the last id in the batch + 1
|
||||
state.next_batch_start_id = ids.last().unwrap() + 1;
|
||||
// Increment batch id
|
||||
state.next_batch_id += 1;
|
||||
|
||||
return Some((ids, batch));
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl From<GenerateParameters> for LogitsWarperParameters {
|
||||
|
@ -23,129 +180,3 @@ impl From<GenerateParameters> for LogitsWarperParameters {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct Db {
|
||||
pub shared: Arc<Shared>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Shared {
|
||||
state: RwLock<State>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct State {
|
||||
entries: BTreeMap<u64, Entry>,
|
||||
|
||||
/// Identifier to use for the next expiration. Each expiration is associated
|
||||
/// with a unique identifier. See above for why.
|
||||
next_id: u64,
|
||||
|
||||
next_batch_id: u64,
|
||||
|
||||
/// Current batch id
|
||||
next_batch_start_id: u64,
|
||||
}
|
||||
|
||||
impl Db {
|
||||
pub(crate) fn new() -> Self {
|
||||
let shared = Arc::new(Shared {
|
||||
state: RwLock::new(State {
|
||||
entries: BTreeMap::new(),
|
||||
next_id: 0,
|
||||
next_batch_id: 0,
|
||||
next_batch_start_id: 0,
|
||||
}),
|
||||
});
|
||||
|
||||
Self { shared }
|
||||
}
|
||||
|
||||
pub(crate) fn append(&self, entry: Entry) {
|
||||
let mut state = self.shared.state.write();
|
||||
|
||||
let id = state.next_id;
|
||||
state.next_id += 1;
|
||||
|
||||
state.entries.insert(id, entry);
|
||||
}
|
||||
|
||||
pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
|
||||
let mut state = self.shared.state.write();
|
||||
state.entries.remove(id)
|
||||
}
|
||||
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
let state = self.shared.state.read();
|
||||
state.entries.len()
|
||||
}
|
||||
|
||||
fn next_requests(&self, max_size: usize) -> Option<(u64, Vec<Request>)> {
|
||||
let state = self.shared.state.read();
|
||||
|
||||
let requests: Vec<Request> = state
|
||||
.entries
|
||||
.range(state.next_batch_start_id..)
|
||||
.take(max_size)
|
||||
.map(|(id, entry)| Request {
|
||||
id: *id,
|
||||
inputs: entry.request.inputs.clone(),
|
||||
input_length: entry.input_length as u32,
|
||||
parameters: Some(LogitsWarperParameters::from(
|
||||
entry.request.parameters.clone(),
|
||||
)),
|
||||
max_new_tokens: entry.request.parameters.max_new_tokens,
|
||||
})
|
||||
.collect();
|
||||
|
||||
if requests.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let last_id = requests.last().unwrap().id;
|
||||
Some((last_id, requests))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn next_batch(&self, max_size: usize) -> Option<Batch> {
|
||||
if let Some((last_id, requests)) = self.next_requests(max_size) {
|
||||
let mut state = self.shared.state.write();
|
||||
let size = requests.len();
|
||||
let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap();
|
||||
let batch = Batch {
|
||||
id: state.next_batch_id,
|
||||
requests,
|
||||
size: size as u32,
|
||||
max_sequence_length,
|
||||
};
|
||||
state.next_batch_start_id = last_id + 1;
|
||||
state.next_batch_id += 1;
|
||||
return Some(batch);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub(crate) fn next_batch_minimum_size(
|
||||
&self,
|
||||
min_size: usize,
|
||||
max_size: usize,
|
||||
) -> Option<Batch> {
|
||||
if let Some((last_id, requests)) = self.next_requests(max_size) {
|
||||
if requests.len() >= min_size {
|
||||
let mut state = self.shared.state.write();
|
||||
let size = requests.len();
|
||||
let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap();
|
||||
let batch = Batch {
|
||||
id: state.next_batch_id,
|
||||
requests,
|
||||
size: size as u32,
|
||||
max_sequence_length,
|
||||
};
|
||||
state.next_batch_start_id = last_id + 1;
|
||||
state.next_batch_id += 1;
|
||||
return Some(batch);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,68 @@
|
|||
/// Text Generation Inference Webserver
|
||||
mod batcher;
|
||||
mod db;
|
||||
mod validation;
|
||||
pub mod server;
|
||||
mod validation;
|
||||
|
||||
use db::{Db, Entry};
|
||||
use batcher::Batcher;
|
||||
use db::{Db, Entry};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use validation::Validation;
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub(crate) struct GenerateParameters {
|
||||
#[serde(default = "default_temperature")]
|
||||
pub temperature: f32,
|
||||
#[serde(default = "default_top_k")]
|
||||
pub top_k: i32,
|
||||
#[serde(default = "default_top_p")]
|
||||
pub top_p: f32,
|
||||
#[serde(default = "default_do_sample")]
|
||||
pub do_sample: bool,
|
||||
#[serde(default = "default_max_new_tokens")]
|
||||
pub max_new_tokens: u32,
|
||||
}
|
||||
|
||||
fn default_temperature() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
fn default_top_k() -> i32 {
|
||||
0
|
||||
}
|
||||
|
||||
fn default_top_p() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
fn default_do_sample() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn default_max_new_tokens() -> u32 {
|
||||
20
|
||||
}
|
||||
|
||||
fn default_parameters() -> GenerateParameters {
|
||||
GenerateParameters {
|
||||
temperature: default_temperature(),
|
||||
top_k: default_top_k(),
|
||||
top_p: default_top_p(),
|
||||
do_sample: default_do_sample(),
|
||||
max_new_tokens: default_max_new_tokens(),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub(crate) struct GenerateRequest {
|
||||
pub inputs: String,
|
||||
#[serde(default = "default_parameters")]
|
||||
pub parameters: GenerateParameters,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub(crate) struct GeneratedText {
|
||||
pub generated_text: String,
|
||||
}
|
||||
|
||||
pub(crate) type GenerateResponse = Vec<GeneratedText>;
|
||||
|
|
|
@ -1,37 +1,61 @@
|
|||
/// Text Generation Inference webserver entrypoint
|
||||
use bloom_inference_client::ShardedClient;
|
||||
use clap::Parser;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::time::Duration;
|
||||
use text_generation_router::server;
|
||||
use tokenizers::Tokenizer;
|
||||
use clap::Parser;
|
||||
|
||||
/// App Configuration
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
#[clap(default_value = "32", long, short, env)]
|
||||
#[clap(default_value = "128", long, env)]
|
||||
max_concurrent_requests: usize,
|
||||
#[clap(default_value = "1000", long, env)]
|
||||
max_input_length: usize,
|
||||
#[clap(default_value = "32", long, env)]
|
||||
max_batch_size: usize,
|
||||
#[clap(default_value = "5", long, env)]
|
||||
max_waiting_time: u64,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
port: u16,
|
||||
#[clap(default_value = "/tmp/bloom-inference-0", long, env)]
|
||||
shard_uds_path: String,
|
||||
master_shard_uds_path: String,
|
||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||
tokenizer_name: String,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
validation_workers: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<(), std::io::Error> {
|
||||
// Get args
|
||||
let args = Args::parse();
|
||||
// Pattern match configuration
|
||||
// Pattern match configuration
|
||||
let Args {
|
||||
max_concurrent_requests,
|
||||
max_input_length,
|
||||
max_batch_size,
|
||||
max_waiting_time,
|
||||
port,
|
||||
shard_uds_path,
|
||||
master_shard_uds_path,
|
||||
tokenizer_name,
|
||||
validation_workers,
|
||||
} = args;
|
||||
|
||||
if validation_workers == 1 {
|
||||
panic!("validation_workers must be > 0");
|
||||
}
|
||||
|
||||
let max_waiting_time = Duration::from_secs(max_waiting_time);
|
||||
|
||||
// Download and instantiate tokenizer
|
||||
// This will only be used to validate payloads
|
||||
//
|
||||
// We need to download it outside of the Tokio runtime
|
||||
let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();
|
||||
|
||||
// Launch Tokio runtime
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
|
@ -39,18 +63,32 @@ fn main() -> Result<(), std::io::Error> {
|
|||
.block_on(async {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let sharded_client = ShardedClient::connect_uds(shard_uds_path)
|
||||
// Instantiate sharded client from the master unix socket
|
||||
let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||
.await
|
||||
.expect("Could not connect to server");
|
||||
// Clear the cache; useful if the webserver rebooted
|
||||
sharded_client
|
||||
.clear_cache()
|
||||
.await
|
||||
.expect("Unable to clear cache");
|
||||
tracing::info!("Connected");
|
||||
|
||||
// Binds on localhost
|
||||
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
|
||||
|
||||
server::run(max_batch_size, sharded_client, tokenizer, addr).await;
|
||||
// Run server
|
||||
server::run(
|
||||
max_concurrent_requests,
|
||||
max_input_length,
|
||||
max_batch_size,
|
||||
max_waiting_time,
|
||||
sharded_client,
|
||||
tokenizer,
|
||||
validation_workers,
|
||||
addr,
|
||||
)
|
||||
.await;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,68 +1,44 @@
|
|||
use crate::{Batcher, Validation};
|
||||
use crate::{
|
||||
Batcher, GenerateParameters, GenerateRequest, GenerateResponse, GeneratedText, Validation,
|
||||
};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::StatusCode;
|
||||
use axum::routing::{get, post};
|
||||
use axum::{Json, Router};
|
||||
use bloom_inference_client::ShardedClient;
|
||||
use serde::Deserialize;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::signal;
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio::time::Instant;
|
||||
use tracing::instrument;
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub(crate) struct GenerateParameters {
|
||||
#[serde(default = "default_temperature")]
|
||||
pub temperature: f32,
|
||||
#[serde(default = "default_top_k")]
|
||||
pub top_k: i32,
|
||||
#[serde(default = "default_top_p")]
|
||||
pub top_p: f32,
|
||||
#[serde(default = "default_do_sample")]
|
||||
pub do_sample: bool,
|
||||
#[serde(default = "default_max_new_tokens")]
|
||||
pub max_new_tokens: u32,
|
||||
}
|
||||
|
||||
fn default_temperature() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
fn default_top_k() -> i32 {
|
||||
0
|
||||
}
|
||||
|
||||
fn default_top_p() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
fn default_do_sample() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn default_max_new_tokens() -> u32 {
|
||||
20
|
||||
}
|
||||
|
||||
fn default_parameters() -> GenerateParameters {
|
||||
GenerateParameters {
|
||||
temperature: default_temperature(),
|
||||
top_k: default_top_k(),
|
||||
top_p: default_top_p(),
|
||||
do_sample: default_do_sample(),
|
||||
max_new_tokens: default_max_new_tokens(),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub(crate) struct GenerateRequest {
|
||||
pub inputs: String,
|
||||
#[serde(default = "default_parameters")]
|
||||
pub parameters: GenerateParameters,
|
||||
// Server shared state
|
||||
#[derive(Clone)]
|
||||
struct ServerState {
|
||||
validation: Validation,
|
||||
batcher: Batcher,
|
||||
limit_concurrent_requests: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
/// Health check method
|
||||
#[instrument(skip(state), fields(time, time_per_token))]
|
||||
async fn liveness(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
|
||||
async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
|
||||
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
||||
// be a bit too slow for a health check.
|
||||
// What we should do instead if check if the gRPC channels are still healthy.
|
||||
|
||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
|
||||
(
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
"Model is overloaded".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Send a small inference request
|
||||
state
|
||||
.batcher
|
||||
.infer(
|
||||
|
@ -82,23 +58,35 @@ async fn liveness(state: Extension<ServerState>) -> Result<(), (StatusCode, Stri
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Generate method
|
||||
#[instrument(skip(state), fields(time, time_per_token))]
|
||||
async fn generate(
|
||||
state: Extension<ServerState>,
|
||||
req: Json<GenerateRequest>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
|
||||
) -> Result<Json<GenerateResponse>, (StatusCode, String)> {
|
||||
let start = Instant::now();
|
||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
|
||||
(
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
"Model is overloaded".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Validate request
|
||||
let (input_length, validated_request) = state
|
||||
.validation
|
||||
// FIXME: can't we get rid of the cloning here??
|
||||
.validate(GenerateRequest {
|
||||
inputs: req.inputs.clone(),
|
||||
parameters: req.parameters.clone(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
// Inference
|
||||
let generated_text = state.batcher.infer(input_length, validated_request).await?;
|
||||
|
||||
// Tracing metadata
|
||||
tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
|
||||
tracing::Span::current().record(
|
||||
"time_per_token",
|
||||
|
@ -106,31 +94,71 @@ async fn generate(
|
|||
);
|
||||
tracing::info!("response: {}", generated_text);
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"generated_text": generated_text,
|
||||
})))
|
||||
// Send response
|
||||
let response = vec![GeneratedText { generated_text }];
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ServerState {
|
||||
validation: Validation,
|
||||
batcher: Batcher,
|
||||
}
|
||||
|
||||
pub async fn run(max_batch_size: usize, client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) {
|
||||
let batcher = Batcher::new(client, max_batch_size);
|
||||
let validation = Validation::new(tokenizer);
|
||||
|
||||
let shared_state = ServerState { validation, batcher };
|
||||
/// Serving method
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn run(
|
||||
max_concurrent_requests: usize,
|
||||
max_input_length: usize,
|
||||
max_batch_size: usize,
|
||||
max_waiting_time: Duration,
|
||||
client: ShardedClient,
|
||||
tokenizer: Tokenizer,
|
||||
validation_workers: usize,
|
||||
addr: SocketAddr,
|
||||
) {
|
||||
// Create state
|
||||
let batcher = Batcher::new(client, max_batch_size, max_waiting_time);
|
||||
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
|
||||
let shared_state = ServerState {
|
||||
validation,
|
||||
batcher,
|
||||
limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
|
||||
};
|
||||
|
||||
// Create router
|
||||
let app = Router::new()
|
||||
.route("/generate", post(generate))
|
||||
.layer(Extension(shared_state.clone()))
|
||||
.route("/health", get(liveness))
|
||||
.route("/health", get(health))
|
||||
.layer(Extension(shared_state.clone()));
|
||||
|
||||
// Run server
|
||||
axum::Server::bind(&addr)
|
||||
.serve(app.into_make_service())
|
||||
// Wait until all requests are finished to shut down
|
||||
.with_graceful_shutdown(shutdown_signal())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
/// Shutdown signal handler
|
||||
async fn shutdown_signal() {
|
||||
let ctrl_c = async {
|
||||
signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to install Ctrl+C handler");
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||
.expect("failed to install signal handler")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let terminate = std::future::pending::<()>();
|
||||
|
||||
tokio::select! {
|
||||
_ = ctrl_c => {},
|
||||
_ = terminate => {},
|
||||
}
|
||||
|
||||
tracing::info!("signal received, starting graceful shutdown");
|
||||
}
|
||||
|
|
|
@ -1,62 +1,105 @@
|
|||
use crate::server::GenerateRequest;
|
||||
/// Payload validation logic
|
||||
use crate::GenerateRequest;
|
||||
use axum::http::StatusCode;
|
||||
use thiserror::Error;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokenizers::{
|
||||
DecoderWrapper, ModelWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper,
|
||||
TokenizerImpl,
|
||||
};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ValidationError {
|
||||
#[error("Temperature must be strictly positive")]
|
||||
Temperature,
|
||||
#[error("Top p must be <= 0.0 or > 1.0")]
|
||||
TopP,
|
||||
#[error("Top k must be strictly positive")]
|
||||
TopK,
|
||||
#[error("Max New Tokens must be < 512")]
|
||||
MaxNewTokens,
|
||||
#[error("Inputs must have less than 1000 tokens. Given: {0}")]
|
||||
InputLength(usize),
|
||||
}
|
||||
|
||||
impl From<ValidationError> for (StatusCode, String) {
|
||||
fn from(err: ValidationError) -> Self {
|
||||
(StatusCode::BAD_REQUEST, err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
type ValidationRequest = (
|
||||
GenerateRequest,
|
||||
oneshot::Sender<Result<(usize, GenerateRequest), ValidationError>>,
|
||||
);
|
||||
|
||||
/// Validation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Validation {
|
||||
/// Channel to communicate with the background validation task
|
||||
sender: mpsc::Sender<ValidationRequest>,
|
||||
}
|
||||
|
||||
impl Validation {
|
||||
pub(crate) fn new(tokenizer: Tokenizer) -> Self {
|
||||
pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self {
|
||||
// Crate channel
|
||||
let (validation_sender, validation_receiver) = mpsc::channel(128);
|
||||
|
||||
tokio::spawn(validation_task(tokenizer, validation_receiver));
|
||||
// Launch background validation task
|
||||
tokio::spawn(validation_task(
|
||||
workers,
|
||||
tokenizer,
|
||||
max_input_length,
|
||||
validation_receiver,
|
||||
));
|
||||
|
||||
Self {
|
||||
sender: validation_sender,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate a payload and get the number of tokens in the input
|
||||
pub(crate) async fn validate(
|
||||
&self,
|
||||
request: GenerateRequest,
|
||||
) -> Result<(usize, GenerateRequest), ValidationError> {
|
||||
// Create response channel
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
// Send request to the background validation task
|
||||
// Unwrap is safe here
|
||||
self.sender.send((request, sender)).await.unwrap();
|
||||
// Await on response channel
|
||||
// Unwrap is safe here
|
||||
receiver.await.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<ValidationRequest>) {
|
||||
while let Some((request, response_tx)) = receiver.recv().await {
|
||||
/// Validation task
|
||||
/// Load balance the validation requests between multiple validation workers
|
||||
async fn validation_task(
|
||||
workers: usize,
|
||||
tokenizer: Tokenizer,
|
||||
max_input_length: usize,
|
||||
mut receiver: mpsc::Receiver<ValidationRequest>,
|
||||
) {
|
||||
let mut workers_senders = Vec::with_capacity(workers);
|
||||
|
||||
// Create workers
|
||||
for _ in 0..workers {
|
||||
let tokenizer_clone = tokenizer.clone();
|
||||
// Create channel to communicate with worker
|
||||
let (worker_sender, worker_receiver) = mpsc::channel(workers);
|
||||
workers_senders.push(worker_sender);
|
||||
|
||||
// Spawn worker
|
||||
tokio::task::spawn_blocking(move || {
|
||||
validation_worker(tokenizer_clone, max_input_length, worker_receiver)
|
||||
});
|
||||
}
|
||||
|
||||
loop {
|
||||
// Load balance requests between workers
|
||||
for sender in workers_senders.iter() {
|
||||
if let Some(validation_request) = receiver.recv().await {
|
||||
sender.send(validation_request).await.unwrap();
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check the parameters inside the payload and get the number of tokens inside the input using
|
||||
/// the tokenizer
|
||||
fn validation_worker(
|
||||
tokenizer: TokenizerImpl<
|
||||
ModelWrapper,
|
||||
NormalizerWrapper,
|
||||
PreTokenizerWrapper,
|
||||
PostProcessorWrapper,
|
||||
DecoderWrapper,
|
||||
>,
|
||||
max_input_length: usize,
|
||||
mut receiver: mpsc::Receiver<ValidationRequest>,
|
||||
) {
|
||||
// Loop over requests
|
||||
while let Some((request, response_tx)) = receiver.blocking_recv() {
|
||||
if request.parameters.temperature < 0.0 {
|
||||
response_tx
|
||||
.send(Err(ValidationError::Temperature))
|
||||
|
@ -78,10 +121,11 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<Vali
|
|||
continue;
|
||||
}
|
||||
|
||||
// Get the number of tokens in the input
|
||||
let inputs = tokenizer.encode(request.inputs.clone(), false).unwrap();
|
||||
let input_length = inputs.len();
|
||||
|
||||
if input_length > 1000 {
|
||||
if input_length > max_input_length {
|
||||
response_tx
|
||||
.send(Err(ValidationError::InputLength(input_length)))
|
||||
.unwrap_or(());
|
||||
|
@ -91,3 +135,28 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<Vali
|
|||
response_tx.send(Ok((input_length, request))).unwrap_or(());
|
||||
}
|
||||
}
|
||||
|
||||
type ValidationRequest = (
|
||||
GenerateRequest,
|
||||
oneshot::Sender<Result<(usize, GenerateRequest), ValidationError>>,
|
||||
);
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ValidationError {
|
||||
#[error("Temperature must be strictly positive")]
|
||||
Temperature,
|
||||
#[error("Top p must be <= 0.0 or > 1.0")]
|
||||
TopP,
|
||||
#[error("Top k must be strictly positive")]
|
||||
TopK,
|
||||
#[error("Max New Tokens must be < 512")]
|
||||
MaxNewTokens,
|
||||
#[error("Inputs must have less than 1000 tokens. Given: {0}")]
|
||||
InputLength(usize),
|
||||
}
|
||||
|
||||
impl From<ValidationError> for (StatusCode, String) {
|
||||
fn from(err: ValidationError) -> Self {
|
||||
(StatusCode::BAD_REQUEST, err.to_string())
|
||||
}
|
||||
}
|
||||
|
|
30
run.sh
30
run.sh
|
@ -1,30 +0,0 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
server_cmd="bloom-inference-server launcher $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"
|
||||
|
||||
# Run in background
|
||||
$server_cmd 2>&1 > /dev/null &
|
||||
|
||||
# Check if server is running by checking if the unix socket is created
|
||||
FILE=/tmp/bloom-inference-0
|
||||
while :
|
||||
do
|
||||
if test -S "$FILE"; then
|
||||
echo "Text Generation Python gRPC server started"
|
||||
break
|
||||
else
|
||||
echo "Waiting for Text Generation Python gRPC server to start"
|
||||
sleep 5
|
||||
fi
|
||||
done
|
||||
|
||||
sleep 1
|
||||
|
||||
# Run in background
|
||||
text-generation-router &
|
||||
|
||||
# Wait for any process to exit
|
||||
wait -n
|
||||
|
||||
# Exit with status of process that exited first
|
||||
exit $?
|
|
@ -0,0 +1,155 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
bloom_inference/__pycache__/
|
||||
bloom_inference/pb/__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
|
@ -4,17 +4,28 @@ gen-server:
|
|||
find bloom_inference/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
|
||||
touch bloom_inference/pb/__init__.py
|
||||
|
||||
unit-tests:
|
||||
python -m pytest --cov=bloom_inference tests
|
||||
install-transformers:
|
||||
# Install specific version of transformers
|
||||
rm transformers || true
|
||||
wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
|
||||
unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
|
||||
rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
|
||||
mv transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 transformers
|
||||
cd transformers && python setup.py install
|
||||
|
||||
unit-tests-reporting:
|
||||
python -m pytest --junitxml=report.xml --cov=bloom_inference tests
|
||||
install-torch:
|
||||
# Install specific version of torch
|
||||
pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir
|
||||
|
||||
pip-install:
|
||||
pip install grpcio-tools
|
||||
make gen-server
|
||||
make install-torch
|
||||
make install-transformers
|
||||
pip install .
|
||||
|
||||
install:
|
||||
poetry install
|
||||
make gen-server
|
||||
make install-torch
|
||||
make install-transformers
|
||||
|
|
|
@ -1,41 +1,51 @@
|
|||
import os
|
||||
import typer
|
||||
|
||||
from pathlib import Path
|
||||
from torch.distributed.launcher import launch_agent, LaunchConfig
|
||||
from typing import Optional
|
||||
|
||||
from bloom_inference import server
|
||||
from bloom_inference import prepare_weights, server
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def launcher(
|
||||
model_name: str,
|
||||
num_gpus: int = 1,
|
||||
shard_directory: Optional[Path] = None,
|
||||
):
|
||||
if num_gpus == 1:
|
||||
serve(model_name, False, shard_directory)
|
||||
|
||||
else:
|
||||
config = LaunchConfig(
|
||||
min_nodes=1,
|
||||
max_nodes=1,
|
||||
nproc_per_node=num_gpus,
|
||||
rdzv_backend="c10d",
|
||||
max_restarts=0,
|
||||
)
|
||||
launch_agent(config, server.serve, [model_name, True, shard_directory])
|
||||
|
||||
|
||||
@app.command()
|
||||
def serve(
|
||||
model_name: str,
|
||||
sharded: bool = False,
|
||||
shard_directory: Optional[Path] = None,
|
||||
uds_path: Path = "/tmp/bloom-inference",
|
||||
):
|
||||
server.serve(model_name, sharded, shard_directory)
|
||||
if sharded:
|
||||
assert (
|
||||
shard_directory is not None
|
||||
), "shard_directory must be set when sharded is True"
|
||||
assert (
|
||||
os.getenv("RANK", None) is not None
|
||||
), "RANK must be set when sharded is True"
|
||||
assert (
|
||||
os.getenv("WORLD_SIZE", None) is not None
|
||||
), "WORLD_SIZE must be set when sharded is True"
|
||||
assert (
|
||||
os.getenv("MASTER_ADDR", None) is not None
|
||||
), "MASTER_ADDR must be set when sharded is True"
|
||||
assert (
|
||||
os.getenv("MASTER_PORT", None) is not None
|
||||
), "MASTER_PORT must be set when sharded is True"
|
||||
|
||||
server.serve(model_name, sharded, uds_path, shard_directory)
|
||||
|
||||
|
||||
@app.command()
|
||||
def prepare_weights(
|
||||
model_name: str,
|
||||
shard_directory: Path,
|
||||
cache_directory: Path,
|
||||
num_shard: int = 1,
|
||||
):
|
||||
prepare_weights.prepare_weights(
|
||||
model_name, cache_directory, shard_directory, num_shard
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -24,6 +24,7 @@ torch.manual_seed(0)
|
|||
class Batch:
|
||||
batch_id: int
|
||||
requests: List[generate_pb2.Request]
|
||||
all_input_lengths: List[int]
|
||||
input_ids: Dict[str, torch.Tensor]
|
||||
all_input_ids: List[torch.Tensor]
|
||||
next_token_choosers: List[NextTokenChooser]
|
||||
|
@ -46,12 +47,12 @@ class Batch:
|
|||
inputs = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
input_lengths = []
|
||||
all_input_lengths = []
|
||||
|
||||
# Parse batch
|
||||
for r in pb.requests:
|
||||
inputs.append(r.inputs)
|
||||
input_lengths.append(r.input_length)
|
||||
all_input_lengths.append(r.input_length)
|
||||
next_token_choosers.append(
|
||||
NextTokenChooser(
|
||||
temperature=r.parameters.temperature,
|
||||
|
@ -63,17 +64,12 @@ class Batch:
|
|||
stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))
|
||||
|
||||
input_ids = tokenizer(inputs, return_tensors="pt", padding=True).to(device)
|
||||
# Remove padding from all_input_ids
|
||||
all_input_ids = [
|
||||
input_ids.squeeze(0)[-length:].unsqueeze(-1)
|
||||
for length, input_ids in zip(
|
||||
input_lengths, input_ids["input_ids"].split(1, dim=0)
|
||||
)
|
||||
]
|
||||
all_input_ids = input_ids["input_ids"].unsqueeze(-1)
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
requests=pb.requests,
|
||||
all_input_lengths=all_input_lengths,
|
||||
input_ids=input_ids,
|
||||
all_input_ids=all_input_ids,
|
||||
next_token_choosers=next_token_choosers,
|
||||
|
@ -91,6 +87,7 @@ class Batch:
|
|||
# Batch attributes
|
||||
input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
|
||||
requests = []
|
||||
all_input_lengths = []
|
||||
all_input_ids = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
|
@ -100,6 +97,7 @@ class Batch:
|
|||
start_index = 0
|
||||
for i, batch in enumerate(batches):
|
||||
requests.extend(batch.requests)
|
||||
all_input_lengths.extend(batch.all_input_lengths)
|
||||
all_input_ids.extend(batch.all_input_ids)
|
||||
next_token_choosers.extend(batch.next_token_choosers)
|
||||
stopping_criterias.extend(batch.stopping_criterias)
|
||||
|
@ -198,6 +196,7 @@ class Batch:
|
|||
return cls(
|
||||
batch_id=batches[0].batch_id,
|
||||
requests=requests,
|
||||
all_input_lengths=all_input_lengths,
|
||||
input_ids=input_ids,
|
||||
all_input_ids=all_input_ids,
|
||||
next_token_choosers=next_token_choosers,
|
||||
|
@ -227,7 +226,10 @@ class BLOOM:
|
|||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
|
||||
self.model = (
|
||||
AutoModelForCausalLM.from_pretrained(model_name).eval().to(self.device).to(dtype)
|
||||
AutoModelForCausalLM.from_pretrained(model_name)
|
||||
.eval()
|
||||
.to(self.device)
|
||||
.to(dtype)
|
||||
)
|
||||
self.num_heads = self.model.base_model.num_heads
|
||||
|
||||
|
@ -253,6 +255,7 @@ class BLOOM:
|
|||
# New input_ids for next forward
|
||||
next_batch_input_ids = []
|
||||
next_batch_all_input_ids = []
|
||||
next_all_input_lengths = []
|
||||
|
||||
next_batch_size = 0
|
||||
next_batch_max_sequence_length = 0
|
||||
|
@ -263,6 +266,7 @@ class BLOOM:
|
|||
# Zipped iterator
|
||||
iterator = zip(
|
||||
batch.requests,
|
||||
batch.all_input_lengths,
|
||||
outputs.logits,
|
||||
batch.next_token_choosers,
|
||||
batch.stopping_criterias,
|
||||
|
@ -272,6 +276,7 @@ class BLOOM:
|
|||
# For each member of the batch
|
||||
for i, (
|
||||
request,
|
||||
input_length,
|
||||
logits,
|
||||
next_token_chooser,
|
||||
stopping_criteria,
|
||||
|
@ -302,8 +307,10 @@ class BLOOM:
|
|||
next_batch_input_ids.append(next_token)
|
||||
next_batch_all_input_ids.append(all_tokens)
|
||||
next_batch_size += 1
|
||||
new_input_length = input_length + 1
|
||||
next_all_input_lengths.append(new_input_length)
|
||||
next_batch_max_sequence_length = max(
|
||||
next_batch_max_sequence_length, len(all_tokens)
|
||||
next_batch_max_sequence_length, new_input_length
|
||||
)
|
||||
|
||||
# We finished all generations in the batch; there is no next batch
|
||||
|
@ -350,6 +357,7 @@ class BLOOM:
|
|||
next_batch = Batch(
|
||||
batch_id=batch.batch_id,
|
||||
requests=next_batch_requests,
|
||||
all_input_lengths=next_all_input_lengths,
|
||||
input_ids=next_batch_input_ids,
|
||||
all_input_ids=next_batch_all_input_ids,
|
||||
next_token_choosers=next_batch_next_token_choosers,
|
||||
|
@ -378,7 +386,10 @@ class BLOOMSharded(BLOOM):
|
|||
if self.master:
|
||||
# TODO @thomasw21 do some caching
|
||||
shard_state_dict_paths = prepare_weights(
|
||||
model_name, shard_directory / "cache", shard_directory, tp_world_size=self.world_size
|
||||
model_name,
|
||||
shard_directory / "cache",
|
||||
shard_directory,
|
||||
tp_world_size=self.world_size,
|
||||
)
|
||||
shard_state_dict_paths = [
|
||||
str(path.absolute()) for path in shard_state_dict_paths
|
||||
|
@ -443,6 +454,7 @@ class BLOOMSharded(BLOOM):
|
|||
use_cache=True,
|
||||
)
|
||||
|
||||
# Logits are sharded, so we need to gather them
|
||||
logits_shard = outputs.logits[:, -1, :].contiguous()
|
||||
|
||||
batch_size, vocab_shard_size = logits_shard.shape
|
||||
|
|
|
@ -14,7 +14,7 @@ from huggingface_hub.file_download import _request_wrapper, hf_raise_for_status
|
|||
|
||||
|
||||
def match_suffix(text, suffix):
|
||||
return text[-len(suffix):] == suffix
|
||||
return text[-len(suffix) :] == suffix
|
||||
|
||||
|
||||
def http_get(
|
||||
|
@ -54,7 +54,9 @@ def cache_download_url(url: str, root_dir: Path):
|
|||
return filename
|
||||
|
||||
|
||||
def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world_size: int):
|
||||
def prepare_weights(
|
||||
model_name: str, cache_path: Path, save_path: Path, tp_world_size: int
|
||||
):
|
||||
save_paths = [
|
||||
save_path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
|
||||
for tp_rank in range(tp_world_size)
|
||||
|
@ -68,6 +70,7 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
|
|||
if model_name == "bigscience/bloom-560m":
|
||||
url = hf_hub_url(model_name, filename="pytorch_model.bin")
|
||||
cache_download_url(url, cache_path)
|
||||
|
||||
elif model_name == "bigscience/bloom":
|
||||
url = hf_hub_url(model_name, filename="pytorch_model.bin.index.json")
|
||||
index_path = cache_download_url(url, cache_path)
|
||||
|
@ -75,10 +78,14 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
|
|||
index = json.load(f)
|
||||
|
||||
# Get unique file names
|
||||
weight_files = list(set([filename for filename in index["weight_map"].values()]))
|
||||
weight_files = list(
|
||||
set([filename for filename in index["weight_map"].values()])
|
||||
)
|
||||
urls = [hf_hub_url(model_name, filename=filename) for filename in weight_files]
|
||||
|
||||
Parallel(n_jobs=5)(delayed(cache_download_url)(url, cache_path) for url in tqdm(urls))
|
||||
Parallel(n_jobs=5)(
|
||||
delayed(cache_download_url)(url, cache_path) for url in tqdm(urls)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown model name: {model_name}")
|
||||
|
||||
|
@ -107,7 +114,9 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
|
|||
assert len(sharded_weights) == tp_world_size
|
||||
|
||||
for tp_rank, shard in enumerate(sharded_weights):
|
||||
shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone()
|
||||
shards_state_dicts[tp_rank][
|
||||
"transformer." + state_name
|
||||
] = shard.detach().clone()
|
||||
|
||||
elif match_suffix(state_name, "lm_head.weight"):
|
||||
output_size = state.shape[0]
|
||||
|
@ -132,7 +141,9 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
|
|||
sharded_weights = torch.split(state, block_size, dim=1)
|
||||
assert len(sharded_weights) == tp_world_size
|
||||
for tp_rank, shard in enumerate(sharded_weights):
|
||||
shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone()
|
||||
shards_state_dicts[tp_rank][
|
||||
"transformer." + state_name
|
||||
] = shard.detach().clone()
|
||||
|
||||
elif any(
|
||||
match_suffix(state_name, candidate)
|
||||
|
@ -141,14 +152,20 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
|
|||
"mlp.dense_4h_to_h.bias",
|
||||
]
|
||||
):
|
||||
shards_state_dicts[0]["transformer." + state_name] = state.detach().clone()
|
||||
shards_state_dicts[0][
|
||||
"transformer." + state_name
|
||||
] = state.detach().clone()
|
||||
for tp_rank in range(1, tp_world_size):
|
||||
shards_state_dicts[tp_rank]["transformer." + state_name] = torch.zeros_like(state)
|
||||
shards_state_dicts[tp_rank][
|
||||
"transformer." + state_name
|
||||
] = torch.zeros_like(state)
|
||||
|
||||
else:
|
||||
# We duplicate parameters across tp ranks
|
||||
for tp_rank in range(tp_world_size):
|
||||
shards_state_dicts[tp_rank]["transformer." + state_name] = state.detach().clone()
|
||||
shards_state_dicts[tp_rank][
|
||||
"transformer." + state_name
|
||||
] = state.detach().clone()
|
||||
|
||||
del state_dict[state_name] # delete key from state_dict
|
||||
del state # delete tensor
|
||||
|
@ -166,17 +183,3 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
|
|||
torch.save(shard_state_dict, save_path)
|
||||
|
||||
return save_paths
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
parser.add_argument("--model-name", required=True, type=str)
|
||||
parser.add_argument("--cache-path", required=True, type=str)
|
||||
parser.add_argument("--save-path", required=True, type=str)
|
||||
parser.add_argument("--world-size", required=True, type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
prepare_weights(args.model_name, Path(args.cache_path), Path(args.save_path), args.world_size)
|
||||
|
|
|
@ -64,70 +64,31 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||
batch=next_batch.to_pb() if next_batch else None,
|
||||
)
|
||||
|
||||
async def GenerateUntilFinished(self, request, context):
|
||||
batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device)
|
||||
|
||||
generated_texts = []
|
||||
while not generated_texts:
|
||||
generated_texts, next_batch = self.model.generate_token(batch)
|
||||
batch = next_batch
|
||||
self.cache.set(next_batch)
|
||||
|
||||
return generate_pb2.GenerateUntilFinishedResponse(
|
||||
generated_texts=[
|
||||
generated_text.to_pb() for generated_text in generated_texts
|
||||
],
|
||||
batch=next_batch.to_pb() if next_batch else None,
|
||||
)
|
||||
|
||||
async def GenerateUntilFinishedWithCache(self, request, context):
|
||||
if len(request.batches) == 0:
|
||||
raise ValueError("Must provide at least one batch")
|
||||
|
||||
batches = []
|
||||
for batch_pb in request.batches:
|
||||
batch = self.cache.pop(batch_pb.id)
|
||||
if batch is None:
|
||||
raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
|
||||
batches.append(batch)
|
||||
|
||||
if len(batches) > 1:
|
||||
batch = Batch.concatenate(batches)
|
||||
else:
|
||||
batch = batches[0]
|
||||
|
||||
generated_texts = []
|
||||
while not generated_texts:
|
||||
generated_texts, next_batch = self.model.generate_token(batch)
|
||||
batch = next_batch
|
||||
self.cache.set(next_batch)
|
||||
|
||||
return generate_pb2.GenerateUntilFinishedWithCacheResponse(
|
||||
generated_texts=[
|
||||
generated_text.to_pb() for generated_text in generated_texts
|
||||
],
|
||||
batch=next_batch.to_pb() if next_batch else None,
|
||||
)
|
||||
|
||||
|
||||
def serve(model_name, sharded, shard_directory):
|
||||
def serve(
|
||||
model_name: str,
|
||||
sharded: bool,
|
||||
uds_path: Path,
|
||||
shard_directory: Optional[Path] = None,
|
||||
):
|
||||
async def serve_inner(
|
||||
model_name: str,
|
||||
sharded: bool = False,
|
||||
shard_directory: Optional[Path] = None,
|
||||
):
|
||||
unix_socket_template = "unix:///tmp/bloom-inference-{}"
|
||||
unix_socket_template = "unix://{}-{}"
|
||||
if sharded:
|
||||
if shard_directory is None:
|
||||
raise ValueError("shard_directory must be set when sharded is True")
|
||||
model = BLOOMSharded(model_name, shard_directory)
|
||||
server_urls = [
|
||||
unix_socket_template.format(rank) for rank in range(model.world_size)
|
||||
unix_socket_template.format(uds_path, rank)
|
||||
for rank in range(model.world_size)
|
||||
]
|
||||
local_url = unix_socket_template.format(model.rank)
|
||||
local_url = server_urls[model.rank]
|
||||
else:
|
||||
model = BLOOM(model_name)
|
||||
local_url = unix_socket_template.format(0)
|
||||
local_url = unix_socket_template.format(uds_path, 0)
|
||||
server_urls = [local_url]
|
||||
|
||||
server = aio.server()
|
||||
|
@ -142,6 +103,10 @@ def serve(model_name, sharded, shard_directory):
|
|||
server.add_insecure_port(local_url)
|
||||
await server.start()
|
||||
print("Server started at {}".format(local_url))
|
||||
try:
|
||||
await server.wait_for_termination()
|
||||
except KeyboardInterrupt:
|
||||
print("Signal received. Shutting down")
|
||||
await server.stop(0)
|
||||
|
||||
asyncio.run(serve_inner(model_name, sharded, shard_directory))
|
||||
|
|
|
@ -82,7 +82,6 @@ def initialize_torch_distributed():
|
|||
world_size=world_size,
|
||||
rank=rank,
|
||||
timeout=timedelta(seconds=60),
|
||||
init_method="tcp://localhost:6000",
|
||||
)
|
||||
|
||||
return torch.distributed.distributed_c10d._get_default_group(), rank, world_size
|
||||
|
|
|
@ -205,7 +205,7 @@ python-versions = ">=3.7"
|
|||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "f3dc5b2420183f2e7e9257e372489409d7bd26d1dcc535fc2558ebca50c988c2"
|
||||
content-hash = "a4eef5f52e8d046aa883082c865b0865047f611a3240b18250487d4b6e831496"
|
||||
|
||||
[metadata.files]
|
||||
accelerate = [
|
||||
|
|
|
@ -11,7 +11,6 @@ bloom-inference-server = 'bloom_inference.cli:app'
|
|||
python = "^3.9"
|
||||
protobuf = "^4.21.7"
|
||||
grpcio = "^1.49.1"
|
||||
torch = "^1.12.1"
|
||||
typer = "^0.6.1"
|
||||
grpcio-reflection = "^1.49.1"
|
||||
accelerate = "^0.12.0"
|
||||
|
|
Loading…
Reference in New Issue