v0.1.0
commit f16f2f5ae1 (parent 92c1ecd008)
@ -1,2 +1,2 @@
 aml
-router/target
+target

@ -1 +1,2 @@
 .idea
+target
Cargo.lock (dependency updates):

Version bumps:
- async-trait 0.1.57 -> 0.1.58
- bumpalo 3.11.0 -> 3.11.1
- clap 4.0.15 -> 4.0.17
- itoa 1.0.3 -> 1.0.4
- libc 0.2.134 -> 0.2.135
- parking_lot_core 0.9.3 -> 0.9.4 (now depends on windows-sys 0.42.0)
- proc-macro2 1.0.46 -> 1.0.47
- serde_json 1.0.85 -> 1.0.86
- syn 1.0.101 -> 1.0.102
- tokio-stream 0.1.10 -> 0.1.11
- tower-layer 0.3.1 -> 0.3.2
- tracing 0.1.36 -> 0.1.37
- tracing-attributes 0.1.22 -> 0.1.23
- tracing-core 0.1.29 -> 0.1.30
- tracing-subscriber 0.3.15 -> 0.3.16 (ansi_term dependency replaced by nu-ansi-term)
- unicode-ident 1.0.4 -> 1.0.5

New packages:
- ctrlc 3.2.3 (nix, winapi)
- nix 0.25.0 (autocfg, bitflags, cfg-if, libc)
- nu-ansi-term 0.46.0 (overload, winapi)
- overload 0.1.1
- signal-hook-registry 1.4.0 (libc)
- subprocess 0.2.9 (libc, winapi)
- text-generation-launcher 0.1.0 (clap 4.0.17, ctrlc, subprocess, tracing, tracing-subscriber)
- windows-sys 0.42.0 plus the matching windows_* 0.42.0 target crates (aarch64_gnullvm, aarch64_msvc, i686_gnu, i686_msvc, x86_64_gnu, x86_64_gnullvm, x86_64_msvc)

Other dependency edits:
- text-generation-router: clap 4.0.15 -> 4.0.17
- tokio: signal-hook-registry added
- crates that previously depended on "windows-sys" now reference the version-qualified entry "windows-sys 0.36.1"
@ -0,0 +1,11 @@
+[workspace]
+members = [
+  "router",
+  "router/client",
+  "launcher"
+]
+
+[profile.release]
+debug = 1
+incremental = true
+lto = "off"
Dockerfile (36 lines changed)

@ -1,4 +1,4 @@
-FROM rust:1.64 as builder
+FROM rust:1.64 as router-builder

 WORKDIR /usr/src

@ -9,7 +9,17 @@ WORKDIR /usr/src/router

 RUN cargo install --path .

-FROM nvidia/cuda:11.6.1-devel-ubuntu18.04
+FROM rust:1.64 as launcher-builder
+
+WORKDIR /usr/src
+
+COPY launcher launcher
+
+WORKDIR /usr/src/launcher
+
+RUN cargo install --path .
+
+FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

 ENV LANG=C.UTF-8 \
     LC_ALL=C.UTF-8 \

@ -34,17 +44,15 @@ RUN cd ~ && \
     bash ./Miniconda3-latest-Linux-x86_64.sh -bf -p /opt/miniconda && \
     conda create -n text-generation python=3.9 -y

+WORKDIR /usr/src
+
+COPY server/Makefile server/Makefile
+
 # Install specific version of torch
-RUN /opt/miniconda/envs/text-generation/bin/pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir
+RUN cd server && make install-torch

 # Install specific version of transformers
-RUN wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
-    unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
-    rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip && \
-    cd transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 && \
-    /opt/miniconda/envs/text-generation/bin/python setup.py install
-
-WORKDIR /usr/src
+RUN cd server && make install-transformers

 # Install server
 COPY server server

@ -52,9 +60,7 @@ RUN cd server && \
     /opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir

 # Install router
-COPY --from=builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=launcher-builder /usr/local/cargo/bin/text-generation-launcher /usr/local/bin/text-generation-launcher

-COPY run.sh .
-RUN chmod +x run.sh
-
-CMD ["./run.sh"]
+CMD text-generation-launcher --model-name $MODEL_NAME --num-shard $NUM_GPUS --shard-directory $MODEL_BASE_PATH
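The final image no longer uses a `run.sh` wrapper: the launcher is the container entrypoint, and the `CMD` wires `$MODEL_NAME`, `$NUM_GPUS` and `$MODEL_BASE_PATH` into launcher flags. Because the launcher enables clap's `derive` and `env` features (see its Cargo.toml further down), the same settings could also be picked up straight from the environment. A minimal, illustrative sketch of that mechanism, not the launcher's full argument set:

```rust
// Illustrative sketch of clap's derive + env features; the field names mirror
// the launcher's arguments, but this is not the launcher's actual Args struct.
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    // `--model-name ...` on the command line, or MODEL_NAME in the environment
    #[clap(default_value = "bigscience/bloom-560m", long, env)]
    model_name: String,
    // `--num-shard ...` or NUM_SHARD
    #[clap(long, env)]
    num_shard: Option<usize>,
    // `--shard-directory ...` or SHARD_DIRECTORY
    #[clap(long, env)]
    shard_directory: Option<String>,
}

fn main() {
    let args = Args::parse();
    println!("{:?}", args);
}
```

With `env`, a field such as `model_name` falls back to the `MODEL_NAME` environment variable when the corresponding flag is absent.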
@ -0,0 +1,19 @@
+install-server:
+	cd server && make pip-install
+
+install-router:
+	cd router && cargo install --path .
+
+install-launcher:
+	cd launcher && cargo install --path .
+
+install:
+	make install-server
+	make install-router
+	make install-launcher
+
+run-bloom-560m:
+	text-generation-launcher --model-name bigscience/bloom-560m --shard-directory /tmp/models --num-shard 2
+
+run-bloom:
+	text-generation-launcher --model-name bigscience/bloom --shard-directory /tmp/models --num-shard 8
README.md (55 lines changed)

@ -1,50 +1,51 @@
-# Text Generation Inference
+# LLM Text Generation Inference

-A Rust and gRPC server for text generation inference.
+<div align="center">

-## Load Tests
+![architecture](assets/architecture.jpg)
+
+</div>
+
+A Rust and gRPC server for large language model text generation inference.
+
+## Load Tests for BLOOM

 See `k6/load_test.js`
-We send the default examples with a 1 second delay between each request.
+We send the default examples with a 1 second delay between requests.

 Stages:
-- Ramp up to 50 concurrent requests per second in 1min
-- Ramp up from 50 to 100 concurrent requests per second in 2min
-- Ramp down to 0 concurrent requests per second in 1min
+- Ramp up to 50 vus in 1min
+- Ramp up from 50 to 100 vus in 2min
+- Ramp down to 0 vus in 1min

 | | avg | min | med | max | p(90) | p(95) | RPS |
-|------------------------|-----------|-----------|-----------|------------|-----------|-----------|----------|
-| Original code | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 |
-| ISO with original code | 8.88s | 959.53ms | 8.89s | 17.08s | 13.34s | 14.12s | 5.94 |
-| New batching logic | **5.44s** | **1.27s** | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |
+|--------------------------------------------------------------|-----------|--------------|-----------|------------|-----------|-----------|----------|
+| [Original code](https://github.com/huggingface/transformers_bloom_parallel) | 8.9s | 1s | 9.12s | 16.69s | 13.7s | 14.26s | 5.9 |
+| ISO with original code | 8.88s | **959.53ms** | 8.89s | 17.08s | 13.34s | 14.12s | 5.94 |
+| New batching logic | **5.44s** | 1.27s | **5.28s** | **13.12s** | **7.78s** | **8.92s** | **9.08** |

 ## Install

 ```shell
-cd server
-pip install .
-```
-
-```
-cd router
-cargo build --release
-```
-
-## Run
+make install
 ```

+## Run
+
 ```shell
-python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-directory /dev/shm/models
+make run-bloom-560m
 ```

+## Test
+
 ```shell
-./router/target/release/router
+curl 127.0.0.1:3000/generate \
+    -X POST \
+    -d '{"inputs":"Testing API","parameters":{"max_new_tokens":9}}' \
+    -H 'Content-Type: application/json'
 ```

 ## TODO:

-- [ ] Add docstrings + comments everywhere as the codebase is fairly complicated
-- [ ] Add tests
-- [ ] Add shutdown logic in router and server
-- [ ] Improve multi-processing logic in server
-- [ ] Improve past key layer indexing?
+- [ ] Add tests for the `server/model` logic
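The new `Test` section above exercises the router over HTTP with curl. The same request can be issued from Rust; this snippet is illustrative only and assumes the `reqwest` crate (with the `blocking` and `json` features) and `serde_json`, which are not dependencies of this repository:

```rust
// Illustrative HTTP client for the /generate endpoint shown in the README.
// Assumes reqwest = { features = ["blocking", "json"] } and serde_json.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let body = serde_json::json!({
        "inputs": "Testing API",
        "parameters": { "max_new_tokens": 9 }
    });

    let response = reqwest::blocking::Client::new()
        .post("http://127.0.0.1:3000/generate")
        .json(&body)
        .send()?
        .text()?;

    println!("{}", response);
    Ok(())
}
```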
@ -8,7 +8,7 @@ environment_variables:
   MODEL_NAME: bigscience/bloom
   NUM_GPUS: 8
 environment:
-  image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.3
+  image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference:0.2
 inference_config:
   liveness_route:
     port: 3000

(New binary image, 132 KiB, not shown; referenced in the README as assets/architecture.jpg.)
@ -0,0 +1,13 @@
+[package]
+name = "text-generation-launcher"
+version = "0.1.0"
+edition = "2021"
+authors = ["Olivier Dehaene"]
+description = "Text Generation Launcher"
+
+[dependencies]
+clap = { version = "4.0.15", features = ["derive", "env"] }
+ctrlc = "3.2.3"
+subprocess = "0.2.9"
+tracing = "0.1.37"
+tracing-subscriber = "0.3.16"
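The dependency list hints at how the launcher works: `ctrlc` traps termination signals, `subprocess` spawns and supervises the shard and router processes, and `tracing`/`tracing-subscriber` handle logging. A much simplified, illustrative sketch of that supervision pattern follows; the real implementation is the new launcher source file shown next. The `sleep 3600` child command here is only a stand-in for a shard or the router.

```rust
// Simplified sketch: trap Ctrl-C, spawn one child process, poll it until it
// exits or shutdown is requested, then terminate it with a grace period.
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread::sleep;
use std::time::Duration;
use subprocess::{Popen, PopenConfig};

fn main() {
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();
    ctrlc::set_handler(move || r.store(false, Ordering::SeqCst))
        .expect("Error setting Ctrl-C handler");

    // Hypothetical child command, standing in for a shard or the webserver
    let mut child = Popen::create(&["sleep", "3600"], PopenConfig::default())
        .expect("failed to start child process");

    // Supervise: exit when the child dies or a shutdown signal arrives
    while running.load(Ordering::SeqCst) && child.poll().is_none() {
        sleep(Duration::from_millis(100));
    }

    child.terminate().ok();
    let _ = child.wait_timeout(Duration::from_secs(30));
}
```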
@ -0,0 +1,358 @@
|
||||||
|
use clap::Parser;
|
||||||
|
use std::io::{BufRead, BufReader, Read};
|
||||||
|
use std::path::Path;
|
||||||
|
use std::process::ExitCode;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use std::sync::mpsc::TryRecvError;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::{mpsc, Mutex};
|
||||||
|
use std::thread;
|
||||||
|
use std::thread::sleep;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use std::{fs, io};
|
||||||
|
use subprocess::{Popen, PopenConfig, PopenError, Redirection};
|
||||||
|
|
||||||
|
/// App Configuration
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[clap(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
#[clap(default_value = "bigscience/bloom-560m", long, env)]
|
||||||
|
model_name: String,
|
||||||
|
#[clap(long, env)]
|
||||||
|
num_shard: Option<usize>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
shard_directory: Option<String>,
|
||||||
|
#[clap(default_value = "128", long, env)]
|
||||||
|
max_concurrent_requests: usize,
|
||||||
|
#[clap(default_value = "1000", long, env)]
|
||||||
|
max_input_length: usize,
|
||||||
|
#[clap(default_value = "32", long, env)]
|
||||||
|
max_batch_size: usize,
|
||||||
|
#[clap(default_value = "5", long, env)]
|
||||||
|
max_waiting_time: u64,
|
||||||
|
#[clap(default_value = "3000", long, short, env)]
|
||||||
|
port: u16,
|
||||||
|
#[clap(default_value = "/tmp/text-generation-server", long, env)]
|
||||||
|
shard_uds_path: String,
|
||||||
|
#[clap(default_value = "localhost", long, env)]
|
||||||
|
master_addr: String,
|
||||||
|
#[clap(default_value = "29500", long, env)]
|
||||||
|
master_port: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> ExitCode {
|
||||||
|
tracing_subscriber::fmt::init();
|
||||||
|
|
||||||
|
// Pattern match configuration
|
||||||
|
let Args {
|
||||||
|
model_name,
|
||||||
|
num_shard,
|
||||||
|
shard_directory,
|
||||||
|
max_concurrent_requests,
|
||||||
|
max_input_length,
|
||||||
|
max_batch_size,
|
||||||
|
max_waiting_time,
|
||||||
|
port,
|
||||||
|
shard_uds_path,
|
||||||
|
master_addr,
|
||||||
|
master_port,
|
||||||
|
} = Args::parse();
|
||||||
|
|
||||||
|
// By default we only have one master shard
|
||||||
|
let num_shard = num_shard.unwrap_or(1);
|
||||||
|
|
||||||
|
// Signal handler
|
||||||
|
let running = Arc::new(AtomicBool::new(true));
|
||||||
|
let r = running.clone();
|
||||||
|
ctrlc::set_handler(move || {
|
||||||
|
r.store(false, Ordering::SeqCst);
|
||||||
|
})
|
||||||
|
.expect("Error setting Ctrl-C handler");
|
||||||
|
|
||||||
|
// Shared shutdown bool
|
||||||
|
let shutdown = Arc::new(Mutex::new(false));
|
||||||
|
// Shared shutdown channel
|
||||||
|
// When shutting down, the main thread will wait for all senders to be dropped
|
||||||
|
let (shutdown_sender, shutdown_receiver) = mpsc::channel();
|
||||||
|
|
||||||
|
// Shared channel to track shard status
|
||||||
|
let (status_sender, status_receiver) = mpsc::channel();
|
||||||
|
|
||||||
|
// Start shard processes
|
||||||
|
for rank in 0..num_shard {
|
||||||
|
let model_name = model_name.clone();
|
||||||
|
let uds_path = shard_uds_path.clone();
|
||||||
|
let shard_directory = shard_directory.clone();
|
||||||
|
let master_addr = master_addr.clone();
|
||||||
|
let status_sender = status_sender.clone();
|
||||||
|
let shutdown = shutdown.clone();
|
||||||
|
let shutdown_sender = shutdown_sender.clone();
|
||||||
|
thread::spawn(move || {
|
||||||
|
shard_manager(
|
||||||
|
model_name,
|
||||||
|
uds_path,
|
||||||
|
shard_directory,
|
||||||
|
rank,
|
||||||
|
num_shard,
|
||||||
|
master_addr,
|
||||||
|
master_port,
|
||||||
|
status_sender,
|
||||||
|
shutdown,
|
||||||
|
shutdown_sender,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
drop(shutdown_sender);
|
||||||
|
|
||||||
|
// Wait for shard to start
|
||||||
|
let mut shard_ready = 0;
|
||||||
|
while running.load(Ordering::SeqCst) {
|
||||||
|
match status_receiver.try_recv() {
|
||||||
|
Ok(ShardStatus::Ready) => {
|
||||||
|
shard_ready += 1;
|
||||||
|
if shard_ready == num_shard {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(TryRecvError::Empty) => {
|
||||||
|
sleep(Duration::from_millis(100));
|
||||||
|
}
|
||||||
|
Ok(ShardStatus::Failed((rank, err))) => {
|
||||||
|
tracing::error!("Shard {} failed to start:\n{}", rank, err);
|
||||||
|
shutdown_shards(shutdown, &shutdown_receiver);
|
||||||
|
return ExitCode::FAILURE;
|
||||||
|
}
|
||||||
|
Err(TryRecvError::Disconnected) => {
|
||||||
|
tracing::error!("Shard status channel disconnected");
|
||||||
|
shutdown_shards(shutdown, &shutdown_receiver);
|
||||||
|
return ExitCode::FAILURE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We might have received a termination signal
|
||||||
|
if !running.load(Ordering::SeqCst) {
|
||||||
|
shutdown_shards(shutdown, &shutdown_receiver);
|
||||||
|
return ExitCode::SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
// All shard started
|
||||||
|
// Start webserver
|
||||||
|
tracing::info!("Starting Webserver");
|
||||||
|
let mut webserver = match Popen::create(
|
||||||
|
&[
|
||||||
|
"text-generation-router",
|
||||||
|
"--max-concurrent-requests",
|
||||||
|
&max_concurrent_requests.to_string(),
|
||||||
|
"--max-input-length",
|
||||||
|
&max_input_length.to_string(),
|
||||||
|
"--max-batch-size",
|
||||||
|
&max_batch_size.to_string(),
|
||||||
|
"--max-waiting-time",
|
||||||
|
&max_waiting_time.to_string(),
|
||||||
|
"--port",
|
||||||
|
&port.to_string(),
|
||||||
|
"--master-shard-uds-path",
|
||||||
|
&format!("{}-0", shard_uds_path),
|
||||||
|
"--tokenizer-name",
|
||||||
|
&model_name,
|
||||||
|
],
|
||||||
|
PopenConfig {
|
||||||
|
stdout: Redirection::Pipe,
|
||||||
|
stderr: Redirection::Pipe,
|
||||||
|
// Needed for the shutdown procedure
|
||||||
|
setpgid: true,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(err) => {
|
||||||
|
tracing::error!("Failed to start webserver: {}", err);
|
||||||
|
if let PopenError::IoError(err) = err {
|
||||||
|
if err.kind() == io::ErrorKind::NotFound {
|
||||||
|
tracing::error!("text-generation-router not found in PATH");
|
||||||
|
tracing::error!("Please install it with `make install-router`")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
shutdown_shards(shutdown, &shutdown_receiver);
|
||||||
|
return ExitCode::FAILURE;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Redirect STDOUT and STDERR to the console
|
||||||
|
let webserver_stdout = webserver.stdout.take().unwrap();
|
||||||
|
let webserver_stderr = webserver.stderr.take().unwrap();
|
||||||
|
|
||||||
|
thread::spawn(move || {
|
||||||
|
let stdout = BufReader::new(webserver_stdout);
|
||||||
|
let stderr = BufReader::new(webserver_stderr);
|
||||||
|
for line in stdout.lines() {
|
||||||
|
println!("{}", line.unwrap());
|
||||||
|
}
|
||||||
|
for line in stderr.lines() {
|
||||||
|
println!("{}", line.unwrap());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Default exit code
|
||||||
|
let mut exit_code = ExitCode::SUCCESS;
|
||||||
|
|
||||||
|
while running.load(Ordering::SeqCst) {
|
||||||
|
if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
|
||||||
|
tracing::error!("Shard {} failed:\n{}", rank, err);
|
||||||
|
exit_code = ExitCode::FAILURE;
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
|
||||||
|
match webserver.poll() {
|
||||||
|
Some(_) => {
|
||||||
|
tracing::error!("Webserver Crashed");
|
||||||
|
shutdown_shards(shutdown, &shutdown_receiver);
|
||||||
|
return ExitCode::FAILURE;
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
sleep(Duration::from_millis(100));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Graceful termination
|
||||||
|
webserver.terminate().unwrap();
|
||||||
|
tracing::info!("Waiting for webserver to gracefully shutdown");
|
||||||
|
webserver.wait_timeout(Duration::from_secs(90)).unwrap();
|
||||||
|
tracing::info!("Webserver terminated");
|
||||||
|
shutdown_shards(shutdown, &shutdown_receiver);
|
||||||
|
|
||||||
|
exit_code
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
enum ShardStatus {
|
||||||
|
Ready,
|
||||||
|
Failed((usize, String)),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn shard_manager(
|
||||||
|
model_name: String,
|
||||||
|
uds_path: String,
|
||||||
|
shard_directory: Option<String>,
|
||||||
|
rank: usize,
|
||||||
|
world_size: usize,
|
||||||
|
master_addr: String,
|
||||||
|
master_port: usize,
|
||||||
|
status_sender: mpsc::Sender<ShardStatus>,
|
||||||
|
shutdown: Arc<Mutex<bool>>,
|
||||||
|
_shutdown_sender: mpsc::Sender<()>,
|
||||||
|
) {
|
||||||
|
// Get UDS path
|
||||||
|
let uds_string = format!("{}-{}", uds_path, rank);
|
||||||
|
let uds = Path::new(&uds_string);
|
||||||
|
// Clean previous runs
|
||||||
|
fs::remove_file(uds).unwrap_or_default();
|
||||||
|
|
||||||
|
// Process args
|
||||||
|
let mut shard_argv = vec![
|
||||||
|
"bloom-inference-server".to_string(),
|
||||||
|
"serve".to_string(),
|
||||||
|
model_name,
|
||||||
|
"--uds-path".to_string(),
|
||||||
|
uds_path,
|
||||||
|
];
|
||||||
|
|
||||||
|
if world_size > 1 {
|
||||||
|
shard_argv.push("--sharded".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(shard_directory) = shard_directory {
|
||||||
|
shard_argv.push("--shard-directory".to_string());
|
||||||
|
shard_argv.push(shard_directory);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start process
|
||||||
|
tracing::info!("Starting shard {}", rank);
|
||||||
|
let mut p = match Popen::create(
|
||||||
|
&shard_argv,
|
||||||
|
PopenConfig {
|
||||||
|
stdout: Redirection::Pipe,
|
||||||
|
stderr: Redirection::Pipe,
|
||||||
|
// Needed for the shutdown procedure
|
||||||
|
setpgid: true,
|
||||||
|
// NCCL env vars
|
||||||
|
env: Some(vec![
|
||||||
|
("RANK".parse().unwrap(), rank.to_string().parse().unwrap()),
|
||||||
|
(
|
||||||
|
"WORLD_SIZE".parse().unwrap(),
|
||||||
|
world_size.to_string().parse().unwrap(),
|
||||||
|
),
|
||||||
|
("MASTER_ADDR".parse().unwrap(), master_addr.parse().unwrap()),
|
||||||
|
(
|
||||||
|
"MASTER_PORT".parse().unwrap(),
|
||||||
|
master_port.to_string().parse().unwrap(),
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(err) => {
|
||||||
|
if let PopenError::IoError(ref err) = err {
|
||||||
|
if err.kind() == io::ErrorKind::NotFound {
|
||||||
|
tracing::error!("bloom-inference-server not found in PATH");
|
||||||
|
tracing::error!("Please install it with `make install-server`")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
status_sender
|
||||||
|
.send(ShardStatus::Failed((rank, err.to_string())))
|
||||||
|
.unwrap();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut ready = false;
|
||||||
|
let start_time = Instant::now();
|
||||||
|
loop {
|
||||||
|
// Process exited
|
||||||
|
if p.poll().is_some() {
|
||||||
|
let mut err = String::new();
|
||||||
|
p.stderr.take().unwrap().read_to_string(&mut err).unwrap();
|
||||||
|
status_sender
|
||||||
|
.send(ShardStatus::Failed((rank, err)))
|
||||||
|
.unwrap();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// We received a shutdown signal
|
||||||
|
if *shutdown.lock().unwrap() {
|
||||||
|
p.terminate().unwrap();
|
||||||
|
let _ = p.wait_timeout(Duration::from_secs(90));
|
||||||
|
tracing::info!("Shard {} terminated", rank);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shard is ready
|
||||||
|
if uds.exists() && !ready {
|
||||||
|
tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed());
|
||||||
|
status_sender.send(ShardStatus::Ready).unwrap();
|
||||||
|
ready = true;
|
||||||
|
} else if !ready {
|
||||||
|
tracing::info!("Waiting for shard {} to be ready...", rank);
|
||||||
|
}
|
||||||
|
sleep(Duration::from_secs(5));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receiver<()>) {
|
||||||
|
tracing::info!("Shutting down shards");
|
||||||
|
// Update shutdown value to true
|
||||||
|
// This will be picked up by the shard manager
|
||||||
|
{
|
||||||
|
let mut shutdown = shutdown.lock().unwrap();
|
||||||
|
*shutdown = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for shards to shutdown
|
||||||
|
// This will block till all shutdown_sender are dropped
|
||||||
|
let _ = shutdown_receiver.recv();
|
||||||
|
}
|
|
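One detail of the launcher code above that is easy to miss is how `shutdown_shards` knows when every shard is gone: each shard-manager thread owns a clone of `shutdown_sender` that it never sends on, and `shutdown_receiver.recv()` only returns (with an error) once all of those clones have been dropped, that is, once every thread has exited. A self-contained sketch of that idiom, not taken from the commit itself:

```rust
// Sender-drop synchronization: recv() unblocks only after all sender clones
// held by the worker threads have been dropped.
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

fn main() {
    let (shutdown_sender, shutdown_receiver) = mpsc::channel::<()>();

    for rank in 0..4u64 {
        let sender = shutdown_sender.clone();
        thread::spawn(move || {
            // Simulated shard work; the sender clone is dropped when the thread exits
            thread::sleep(Duration::from_millis(100 * rank));
            println!("shard {} done", rank);
            drop(sender);
        });
    }
    // Keep only the clones owned by the worker threads
    drop(shutdown_sender);

    // Blocks until every thread has finished (all senders dropped)
    let _ = shutdown_receiver.recv();
    println!("all shards terminated");
}
```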
@ -11,10 +11,6 @@ service TextGenerationService {
     rpc Generate (GenerateRequest) returns (GenerateResponse);
     /// Generate tokens for a list of cached batches
     rpc GenerateWithCache (GenerateWithCacheRequest) returns (GenerateWithCacheResponse);
-    /// Generate tokens until the text of at least one request of the batch is generated
-    rpc GenerateUntilFinished (GenerateUntilFinishedRequest) returns (GenerateUntilFinishedResponse);
-    /// Generate tokens until the text of at least one request of the cached batches i finished
-    rpc GenerateUntilFinishedWithCache (GenerateUntilFinishedWithCacheRequest) returns (GenerateUntilFinishedWithCacheResponse);
 }

 /// Empty request

@ -92,27 +88,3 @@ message GenerateWithCacheResponse {
     /// Next batch (cached)
     optional Batch batch = 2;
 }
-
-message GenerateUntilFinishedRequest {
-    /// Batch
-    Batch batch = 1;
-}
-
-message GenerateUntilFinishedResponse {
-    /// Finished requests
-    repeated GeneratedText generated_texts = 1;
-    /// Next batch (cached)
-    optional Batch batch = 2;
-}
-
-message GenerateUntilFinishedWithCacheRequest {
-    /// Cached batches
-    repeated Batch batches = 1;
-}
-
-message GenerateUntilFinishedWithCacheResponse {
-    /// Finished requests
-    repeated GeneratedText generated_texts = 1;
-    /// Next batch (cached)
-    optional Batch batch = 2;
-}
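With the `GenerateUntilFinished*` methods removed, the service keeps only `Generate` (prefill a new batch and produce one token per request) and `GenerateWithCache` (produce the next token for already cached batches), and the caller is expected to loop. A rough sketch of that loop using the client wrappers introduced later in this commit; the types and method names follow the client code below, but this exact driver function is illustrative and not part of the commit:

```rust
// Illustrative driver: prefill with Generate, then loop on GenerateWithCache
// until no cached batch remains.
use bloom_inference_client::{Batch, Client, ClientError, GeneratedText};

async fn decode_until_finished(
    client: &mut Client,
    batch: Batch,
) -> Result<Vec<GeneratedText>, ClientError> {
    let mut finished = Vec::new();

    // Prefill: one token for every request in the new batch
    let (texts, mut cached) = client.generate(batch).await?;
    finished.extend(texts);

    // Decode: keep generating on the cached batch until every request is done
    while let Some(batch) = cached {
        let (texts, next) = client.generate_with_cache(vec![batch]).await?;
        finished.extend(texts);
        cached = next;
    }
    Ok(finished)
}
```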
|
|
@ -1 +0,0 @@
-/target
|
@ -22,16 +22,7 @@ serde = "1.0.145"
 serde_json = "1.0.85"
 thiserror = "1.0.37"
 tokenizers = "0.13.0"
-tokio = { version = "1.21.1", features = ["rt-multi-thread", "parking_lot", "sync"] }
+tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
 tracing = "0.1.36"
 tracing-subscriber = "0.3.15"
-
-[workspace]
-members = [
-    "client",
-]
-
-[profile.release]
-debug = 1
-incremental = true
-lto = "off"
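The tokio feature list gains `rt` and `signal`. The `signal` feature is what enables `tokio::signal`, typically used to shut a server down gracefully on Ctrl-C; the router's own shutdown code is not part of this diff, so the snippet below only illustrates what the feature provides:

```rust
// Illustrative use of tokio's "signal" feature (not code from this commit)
async fn shutdown_signal() {
    tokio::signal::ctrl_c()
        .await
        .expect("failed to install Ctrl-C handler");
    tracing::info!("shutdown signal received");
}

// Hypothetical wiring with axum's graceful shutdown:
// axum::Server::bind(&addr)
//     .serve(app.into_make_service())
//     .with_graceful_shutdown(shutdown_signal())
//     .await
//     .unwrap();
```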
|
@ -5,8 +5,6 @@ edition = "2021"

 [dependencies]
 futures = "^0.3"
-#grpc-error-details = { path = "../../grpc-error-details" }
-#grpc-metadata = { path = "../../grpc-metadata" }
 prost = "^0.9"
 thiserror = "^1.0"
 tokio = { version = "^1.21", features = ["sync"] }
|
|
@ -1,10 +1,11 @@
|
||||||
|
/// Single shard Client
|
||||||
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
|
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
|
||||||
use crate::pb::generate::v1::*;
|
use crate::pb::generate::v1::*;
|
||||||
use crate::Result;
|
use crate::Result;
|
||||||
use tonic::transport::{Channel, Uri};
|
use tonic::transport::{Channel, Uri};
|
||||||
use tracing::*;
|
use tracing::*;
|
||||||
|
|
||||||
/// BLOOM Inference gRPC client
|
/// Text Generation Inference gRPC client
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct Client {
|
pub struct Client {
|
||||||
stub: TextGenerationServiceClient<Channel>,
|
stub: TextGenerationServiceClient<Channel>,
|
||||||
|
@ -34,6 +35,7 @@ impl Client {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a list of uris or unix sockets of all shards
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
|
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
|
||||||
let request = tonic::Request::new(ServiceDiscoveryRequest {});
|
let request = tonic::Request::new(ServiceDiscoveryRequest {});
|
||||||
|
@ -46,6 +48,7 @@ impl Client {
|
||||||
.into_inner()
|
.into_inner()
|
||||||
.urls
|
.urls
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
// Remove unix socket prefix
|
||||||
.map(|url| match url.strip_prefix("unix://") {
|
.map(|url| match url.strip_prefix("unix://") {
|
||||||
None => url,
|
None => url,
|
||||||
Some(stripped_url) => stripped_url.to_string(),
|
Some(stripped_url) => stripped_url.to_string(),
|
||||||
|
@ -54,6 +57,7 @@ impl Client {
|
||||||
Ok(urls)
|
Ok(urls)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Clear the past generations cache
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn clear_cache(&mut self) -> Result<()> {
|
pub async fn clear_cache(&mut self) -> Result<()> {
|
||||||
let request = tonic::Request::new(ClearCacheRequest {});
|
let request = tonic::Request::new(ClearCacheRequest {});
|
||||||
|
@ -64,6 +68,10 @@ impl Client {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generate one token for each request in the given batch
|
||||||
|
///
|
||||||
|
/// Returns a list of generated texts of request that met their stopping criteria
|
||||||
|
/// and the next cached batch
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
||||||
let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
|
let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
|
||||||
|
@ -76,6 +84,10 @@ impl Client {
|
||||||
Ok((response.generated_texts, response.batch))
|
Ok((response.generated_texts, response.batch))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generate one token for each request in the given cached batch
|
||||||
|
///
|
||||||
|
/// Returns a list of generated texts of request that met their stopping criteria
|
||||||
|
/// and the next cached batch
|
||||||
#[instrument(skip(self))]
|
#[instrument(skip(self))]
|
||||||
pub async fn generate_with_cache(
|
pub async fn generate_with_cache(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
@ -90,34 +102,4 @@ impl Client {
|
||||||
.into_inner();
|
.into_inner();
|
||||||
Ok((response.generated_texts, response.batch))
|
Ok((response.generated_texts, response.batch))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument(skip(self))]
|
|
||||||
pub async fn generate_until_finished(
|
|
||||||
&mut self,
|
|
||||||
batch: Batch,
|
|
||||||
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
|
||||||
let request = tonic::Request::new(GenerateUntilFinishedRequest { batch: Some(batch) });
|
|
||||||
let response = self
|
|
||||||
.stub
|
|
||||||
.generate_until_finished(request)
|
|
||||||
.instrument(info_span!("generate_until_finished"))
|
|
||||||
.await?
|
|
||||||
.into_inner();
|
|
||||||
Ok((response.generated_texts, response.batch))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[instrument(skip(self))]
|
|
||||||
pub async fn generate_until_finished_with_cache(
|
|
||||||
&mut self,
|
|
||||||
batches: Vec<Batch>,
|
|
||||||
) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
|
|
||||||
let request = tonic::Request::new(GenerateUntilFinishedWithCacheRequest { batches });
|
|
||||||
let response = self
|
|
||||||
.stub
|
|
||||||
.generate_until_finished_with_cache(request)
|
|
||||||
.instrument(info_span!("generate_until_finished_with_cache"))
|
|
||||||
.await?
|
|
||||||
.into_inner();
|
|
||||||
Ok((response.generated_texts, response.batch))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
//! BLOOM Inference gRPC client library
|
//! Text Generation gRPC client library
|
||||||
|
|
||||||
mod client;
|
mod client;
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
mod pb;
|
mod pb;
|
||||||
mod sharded_client;
|
mod sharded_client;
|
||||||
|
|
||||||
|
@ -8,7 +9,7 @@ pub use client::Client;
|
||||||
pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request};
|
pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request};
|
||||||
pub use sharded_client::ShardedClient;
|
pub use sharded_client::ShardedClient;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
pub use tonic::transport;
|
use tonic::transport;
|
||||||
use tonic::Status;
|
use tonic::Status;
|
||||||
|
|
||||||
#[derive(Error, Debug, Clone)]
|
#[derive(Error, Debug, Clone)]
|
||||||
|
@ -21,7 +22,7 @@ pub enum ClientError {
|
||||||
|
|
||||||
impl From<Status> for ClientError {
|
impl From<Status> for ClientError {
|
||||||
fn from(err: Status) -> Self {
|
fn from(err: Status) -> Self {
|
||||||
Self::Generation(err.to_string())
|
Self::Generation(err.message().to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
|
/// Multi shard Client
|
||||||
use crate::Result;
|
use crate::Result;
|
||||||
use crate::{Batch, Client, GeneratedText};
|
use crate::{Batch, Client, GeneratedText};
|
||||||
use futures::future::join_all;
|
use futures::future::join_all;
|
||||||
use tokio::sync::{broadcast, mpsc};
|
use tokio::sync::{broadcast, mpsc};
|
||||||
use tonic::transport::Uri;
|
use tonic::transport::Uri;
|
||||||
|
|
||||||
|
/// List of all available commands that can be sent through the command channel
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
enum Command {
|
enum Command {
|
||||||
Generate(
|
Generate(
|
||||||
|
@ -14,36 +16,32 @@ enum Command {
|
||||||
Vec<Batch>,
|
Vec<Batch>,
|
||||||
mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
|
mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
|
||||||
),
|
),
|
||||||
GenerateUntilFinished(
|
|
||||||
Batch,
|
|
||||||
mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
|
|
||||||
),
|
|
||||||
GenerateUntilFinishedWithCache(
|
|
||||||
Vec<Batch>,
|
|
||||||
mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
|
|
||||||
),
|
|
||||||
ClearCache(mpsc::Sender<Result<()>>),
|
ClearCache(mpsc::Sender<Result<()>>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Tokio task that handles the communication with a single shard
|
||||||
|
///
|
||||||
|
+/// We subscribe on a broadcast channel to receive commands that will be sent by
+/// the ShardedClient.
+///
+/// Each command is fanned out to all shards.
+///
+/// The result of the command is sent back to the ShardedClient through a mpsc channel (multi
+/// producer = the shards, single consumer = the ShardedClient).
 async fn client_task(mut client: Client, mut request_subscriber: broadcast::Receiver<Command>) {
     while let Ok(message) = request_subscriber.recv().await {
         match message {
             Command::Generate(batch, response_tx) => {
                 let result = client.generate(batch).await;
+                // We can unwrap_or(()) here because the only error that can happen is if the
+                // receiver is dropped, which means that the ShardedClient already received a
+                // response from another shard
                 response_tx.try_send(result).unwrap_or(());
             }
             Command::GenerateWithCache(batches, response_tx) => {
                 let result = client.generate_with_cache(batches).await;
                 response_tx.try_send(result).unwrap_or(());
             }
-            Command::GenerateUntilFinished(batch, response_tx) => {
-                let result = client.generate_until_finished(batch).await;
-                response_tx.try_send(result).unwrap_or(());
-            }
-            Command::GenerateUntilFinishedWithCache(batches, response_tx) => {
-                let result = client.generate_until_finished_with_cache(batches).await;
-                response_tx.try_send(result).unwrap_or(());
-            }
             Command::ClearCache(response_tx) => {
                 let result = client.clear_cache().await;
                 response_tx.try_send(result).unwrap_or(());
@@ -52,30 +50,42 @@ async fn client_task(mut client: Client, mut request_subscriber: broadcast::Rece
             }
         }

+/// Text Generation Inference gRPC multi client
 pub struct ShardedClient {
+    _clients: Vec<Client>,
     request_tx: broadcast::Sender<Command>,
 }

 impl ShardedClient {
-    fn new(mut clients: Vec<Client>) -> Self {
+    fn new(clients: Vec<Client>) -> Self {
+        // The broadcast channel to communicate with the shards
+        // We use a capacity of one as the shards are not asynchronous and can only process one
+        // command at a time
         let (request_tx, _) = broadcast::channel(1);

-        for client in clients.drain(..) {
+        // Spawn client tasks
+        for client in clients.iter() {
             let request_subscriber = request_tx.subscribe();
-            tokio::spawn(client_task(client, request_subscriber));
+            tokio::spawn(client_task(client.clone(), request_subscriber));
         }

-        Self { request_tx }
+        Self {
+            _clients: clients,
+            request_tx,
+        }
     }

+    /// Create a new ShardedClient from a master client. The master client will communicate with
+    /// the other shards and return all uris/unix sockets with the `service_discovery` gRPC method.
     async fn from_master_client(mut master_client: Client) -> Result<Self> {
+        // Get all uris/unix sockets from the master client
         let uris = master_client.service_discovery().await.unwrap();
-        let futures = uris.into_iter().map(|path| Client::connect_uds(path));
+        let futures = uris.into_iter().map(Client::connect_uds);
         let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
         Ok(Self::new(clients?))
     }

-    /// Returns a client connected to the given url
+    /// Returns a client connected to the given uri
     pub async fn connect(uri: Uri) -> Result<Self> {
         let master_client = Client::connect(uri).await?;
         Self::from_master_client(master_client).await
@@ -87,51 +97,43 @@ impl ShardedClient {
         Self::from_master_client(master_client).await
     }

+    /// Generate one token for each request in the given batch
+    ///
+    /// Returns a list of generated texts of requests that met their stopping criteria
+    /// and the next cached batch
     pub async fn generate(&self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+        // Create a channel to receive the response from the shards
+        // We will only ever receive one message on this channel
         let (response_tx, mut response_rx) = mpsc::channel(1);
         self.request_tx
             .send(Command::Generate(batch, response_tx))
             .unwrap();
+        // As soon as we receive one response, we can return as all shards will return the same
         response_rx.recv().await.unwrap()
     }

+    /// Generate one token for each request in the given cached batch
+    ///
+    /// Returns a list of generated texts of requests that met their stopping criteria
+    /// and the next cached batch
     pub async fn generate_with_cache(
         &self,
         batches: Vec<Batch>,
     ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+        // Create a channel to receive the response from the shards
+        // We will only ever receive one message on this channel
         let (response_tx, mut response_rx) = mpsc::channel(1);
         self.request_tx
             .send(Command::GenerateWithCache(batches, response_tx))
             .unwrap();
+        // As soon as we receive one response, we can return as all shards will return the same
         response_rx.recv().await.unwrap()
     }

-    pub async fn generate_until_finished(
-        &self,
-        batch: Batch,
-    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
-        let (response_tx, mut response_rx) = mpsc::channel(1);
-        self.request_tx
-            .send(Command::GenerateUntilFinished(batch, response_tx))
-            .unwrap();
-        response_rx.recv().await.unwrap()
-    }
-
-    pub async fn generate_until_finished_with_cache(
-        &self,
-        batches: Vec<Batch>,
-    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
-        let (response_tx, mut response_rx) = mpsc::channel(1);
-        self.request_tx
-            .send(Command::GenerateUntilFinishedWithCache(
-                batches,
-                response_tx,
-            ))
-            .unwrap();
-        response_rx.recv().await.unwrap()
-    }
-
+    /// Clear the past generations cache
     pub async fn clear_cache(&self) -> Result<()> {
+        // Create a channel to receive the response from the shards
+        // We will only ever receive one message on this channel
         let (response_tx, mut response_rx) = mpsc::channel(1);
         self.request_tx
             .send(Command::ClearCache(response_tx))
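Illustrative sketch (not part of this commit): the broadcast fan-out plus first-response pattern that client_task and the ShardedClient rely on, reduced to plain strings so it runs on its own. The Command enum, worker_task function and the string payloads are invented for the example; only tokio's broadcast and mpsc channels are real APIs.

// Fan out one command to several workers, keep only the first reply.
use tokio::sync::{broadcast, mpsc};

#[derive(Clone)]
enum Command {
    Echo(String, mpsc::Sender<String>),
}

async fn worker_task(id: usize, mut rx: broadcast::Receiver<Command>) {
    while let Ok(Command::Echo(text, reply_tx)) = rx.recv().await {
        // try_send: if the bounded channel is already full, another worker answered first
        reply_tx.try_send(format!("worker {id} saw {text}")).unwrap_or(());
    }
}

#[tokio::main]
async fn main() {
    let (tx, _) = broadcast::channel(1);
    for id in 0..4 {
        tokio::spawn(worker_task(id, tx.subscribe()));
    }
    // Bounded to one message: the first worker to answer wins
    let (reply_tx, mut reply_rx) = mpsc::channel(1);
    tx.send(Command::Echo("hello".to_string(), reply_tx)).unwrap();
    println!("{}", reply_rx.recv().await.unwrap());
}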
@@ -1,129 +1,158 @@
-use crate::server::GenerateRequest;
+/// Batching and inference logic
+use crate::GenerateRequest;
 use crate::{Db, Entry};
 use axum::http::StatusCode;
 use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use std::future::Future;
 use std::sync::Arc;
+use std::time::Duration;
 use thiserror::Error;
 use tokio::sync::{oneshot, Notify};
+use tokio::time::Instant;
+use tracing::instrument;

-const MAX_LENGTH: usize = 128;
-
-#[derive(Debug, Error)]
-pub enum InferError {
-    #[error("Request failed during generation: {0}")]
-    GenerationError(String),
-    #[error("Model is overloaded")]
-    Overloaded,
-}
-
-impl From<InferError> for (StatusCode, String) {
-    fn from(err: InferError) -> Self {
-        match err {
-            InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
-            InferError::Overloaded => (StatusCode::TOO_MANY_REQUESTS, err.to_string()),
-        }
-    }
-}
-
+/// Batcher
 #[derive(Clone)]
 pub struct Batcher {
+    /// Request database
     db: Db,
+    /// Shared state
     shared: Arc<Shared>,
 }

+/// Batcher shared state
 struct Shared {
+    /// Batching background Tokio task notifier
     batching_task: Notify,
 }

 impl Batcher {
-    pub(crate) fn new(client: ShardedClient, max_batch_size: usize) -> Self {
+    pub(crate) fn new(
+        client: ShardedClient,
+        max_batch_size: usize,
+        max_waiting_time: Duration,
+    ) -> Self {
+        // Batcher shared state
         let db = Db::new();
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });

-        tokio::spawn(batching_task(max_batch_size, client, db.clone(), shared.clone()));
+        // Spawn batching background task that contains all the inference logic
+        tokio::spawn(batching_task(
+            max_batch_size,
+            max_waiting_time,
+            client,
+            db.clone(),
+            shared.clone(),
+        ));

         Self { db, shared }
     }

+    /// Add a new request to the database and return a future that will generate the text
     pub(crate) async fn infer(
         &self,
         input_length: usize,
         request: GenerateRequest,
     ) -> Result<String, InferError> {
-        if self.db.len() > MAX_LENGTH {
-            return Err(InferError::Overloaded);
-        }
-        let (request_tx, request_rx) = oneshot::channel();
+        // One shot channel to communicate with the background batching task
+        let (response_tx, response_rx) = oneshot::channel();
+        // Try to append the request to the database
         self.db.append(Entry {
             request,
-            response_tx: request_tx,
+            response_tx,
             input_length,
+            time: Instant::now(),
         });

+        // Notify the background task that we have a new entry in the database that needs
+        // to be batched
         self.shared.batching_task.notify_waiters();
-        match request_rx.await.unwrap() {
+
+        // Await on the response from the background task
+        // We can safely unwrap as the background task will never drop the sender
+        match response_rx.await.unwrap() {
             Ok(output) => Ok(output),
             Err(err) => Err(InferError::GenerationError(err.to_string())),
         }
     }
 }

-async fn batching_task(max_batch_size: usize,
-                       client: ShardedClient,
-                       db: Db,
-                       shared: Arc<Shared>) {
+/// Batching logic
+/// Will be launched in a background Tokio task
+///
+/// Batches requests and sends them to the inference server
+#[instrument(skip(client, db, shared))]
+async fn batching_task(
+    max_batch_size: usize,
+    max_waiting_time: Duration,
+    client: ShardedClient,
+    db: Db,
+    shared: Arc<Shared>,
+) {
+    // Minimum batch size after which we try to add more requests
     let limit_min_batch_size = (max_batch_size / 2) as u32;

+    // Infinite loop
     loop {
+        // Wait for a notification from the Batcher struct
        shared.batching_task.notified().await;

-        if let Some(batch) = db.next_batch(max_batch_size) {
-            let request_ids = batch.requests.iter().map(|req| req.id).collect();
-            let mut cached_batch = match batch.size {
-                size if size > limit_min_batch_size => {
-                    wrap_future(client.generate_until_finished(batch), request_ids, &db).await
-                }
-                _ => wrap_future(client.generate(batch), request_ids, &db).await,
-            };
+        // Get the next batch from the DB
+        // This batch might be smaller than the maximum batch size if there are not enough requests
+        // waiting in the DB
+        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size, None) {
+            let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;

+            // We loop until we do not receive any cached batch from the inference server (== until
+            // all requests have met their stopping criteria)
             while let Some(batch) = cached_batch {
-                let mut current_batch_size = batch.size;
+                // Get current batch info
+                let batch_size = batch.size;
                 let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
                 let mut batches = vec![batch];

-                if current_batch_size <= limit_min_batch_size {
-                    if let Some(new_batch) = db.next_batch_minimum_size(limit_min_batch_size as usize, max_batch_size) {
-                        let new_batch_request_ids =
-                            new_batch.requests.iter().map(|req| req.id).collect();
+                // If the current batch is too small, we try to add more requests to it
+                if batch_size <= limit_min_batch_size {
+                    // Get the next batch from the DB that meets our minimum size criteria
+                    if let Some((new_request_ids, new_batch)) =
+                        db.next_batch(Some(limit_min_batch_size as usize), max_batch_size, None)
+                    {
+                        // Generate one token for this new batch to have the attention past in cache
                         let new_cached_batch =
-                            wrap_future(client.generate(new_batch), new_batch_request_ids, &db)
-                                .await;
+                            wrap_future(client.generate(new_batch), new_request_ids, &db).await;
+                        // Extend current batch with the new batch
+                        if let Some(new_cached_batch) = new_cached_batch {
+                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
+                            batches.push(new_cached_batch);
+                        }
+                    }
+                    // If we don't have enough requests to meet the minimum size criteria, we
+                    // try to get the next batch from the DB that has been waiting over
+                    // the max_waiting_time
+                    else if let Some((new_request_ids, new_batch)) =
+                        db.next_batch(None, max_batch_size, Some(max_waiting_time))
+                    {
+                        let new_cached_batch =
+                            wrap_future(client.generate(new_batch), new_request_ids, &db).await;
+                        // Extend current batch with the new batch
                         if let Some(new_cached_batch) = new_cached_batch {
-                            current_batch_size += new_cached_batch.size;
                             request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
                             batches.push(new_cached_batch);
                         }
                     }
                 }

-                cached_batch = match current_batch_size {
-                    size if size > limit_min_batch_size => {
-                        wrap_future(
-                            client.generate_until_finished_with_cache(batches),
-                            request_ids,
-                            &db,
-                        )
-                        .await
-                    }
-                    _ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await,
-                };
+                cached_batch =
+                    wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
             }
         }
     }
 }

+/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
 async fn wrap_future(
     future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
     request_ids: Vec<u64>,
@@ -134,6 +163,7 @@ async fn wrap_future(
             send_generated(generated_texts, db);
             next_batch
         }
+        // If we have an error, we discard the whole batch
         Err(err) => {
             send_error(err, request_ids, db);
             None
@@ -141,16 +171,20 @@ async fn wrap_future(
     }
 }

+/// Send errors to the Batcher for all `request_ids`
 fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
     request_ids.into_iter().for_each(|id| {
+        // We can `expect` here as the request id should always be in the DB
         let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
         // unwrap_or is valid here as we don't care if the receiver is gone.
         entry.response_tx.send(Err(error.clone())).unwrap_or(());
     });
 }

+/// Send `generated_text` to the Batcher for all `finished`
 fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
     finished.into_iter().for_each(|output| {
+        // We can `expect` here as the request id should always be in the DB
         let entry = db
             .remove(&output.request.unwrap().id)
             .expect("ID not found in db. This is a bug.");
@@ -158,3 +192,18 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
         entry.response_tx.send(Ok(output.output)).unwrap_or(());
     });
 }
+
+#[derive(Debug, Error)]
+pub enum InferError {
+    #[error("Request failed during generation: {0}")]
+    GenerationError(String),
+}
+
+/// Convert to Axum supported format
+impl From<InferError> for (StatusCode, String) {
+    fn from(err: InferError) -> Self {
+        match err {
+            InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
+        }
+    }
+}
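Illustrative sketch (not part of this commit): the Notify plus oneshot handoff behind Batcher::infer and batching_task, reduced to a queue of strings. The Job type and the fake "generation" are invented for the example, and notify_one is used instead of notify_waiters so the sketch cannot miss a wake-up if the background task has not started waiting yet.

// Callers enqueue a job, wake the background task, and await its oneshot reply.
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use tokio::sync::{oneshot, Notify};

type Job = (String, oneshot::Sender<String>);

#[tokio::main]
async fn main() {
    let queue: Arc<Mutex<VecDeque<Job>>> = Arc::new(Mutex::new(VecDeque::new()));
    let notify = Arc::new(Notify::new());

    // Background "batching" task: drains the queue whenever it is notified
    let bg_queue = queue.clone();
    let bg_notify = notify.clone();
    tokio::spawn(async move {
        loop {
            bg_notify.notified().await;
            while let Some((input, reply_tx)) = bg_queue.lock().unwrap().pop_front() {
                reply_tx.send(format!("generated: {input}")).unwrap_or(());
            }
        }
    });

    // Caller side: enqueue, notify, await the oneshot response
    let (reply_tx, reply_rx) = oneshot::channel();
    queue.lock().unwrap().push_back(("hello".to_string(), reply_tx));
    notify.notify_one();
    println!("{}", reply_rx.await.unwrap());
}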
router/src/db.rs | 287
@@ -1,16 +1,173 @@
 /// This code is massively inspired by Tokio mini-redis
-use crate::server::{GenerateParameters, GenerateRequest};
+use crate::{GenerateParameters, GenerateRequest};
 use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
-use parking_lot::RwLock;
+use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
+use std::time::Duration;
 use tokio::sync::oneshot::Sender;
+use tokio::time::Instant;

+/// Database entry
 #[derive(Debug)]
 pub(crate) struct Entry {
+    /// Request
     pub request: GenerateRequest,
+    /// Response sender to communicate between the Batcher and the batching_task
     pub response_tx: Sender<Result<String, ClientError>>,
+    /// Number of tokens in the input
     pub input_length: usize,
+    /// Instant when this entry was created
+    pub time: Instant,
+}
+
+/// Request Database
+#[derive(Debug, Clone)]
+pub(crate) struct Db {
+    pub shared: Arc<Shared>,
+}
+
+/// Shared state
+#[derive(Debug)]
+pub struct Shared {
+    state: Mutex<State>,
+}
+
+/// Database State
+#[derive(Debug)]
+struct State {
+    /// Database entries organized in a BTreeMap to be able to iterate over them in order
+    entries: BTreeMap<u64, Entry>,
+
+    /// Id of the next entry
+    next_id: u64,
+
+    /// Id of the next batch
+    next_batch_id: u64,
+
+    /// Start ID of the next batch. Used to iterate inside the entries BTreeMap
+    next_batch_start_id: u64,
+}
+
+impl State {
+    /// Get the next requests
+    fn next_requests(
+        &self,
+        max_size: usize,
+        min_waiting_time: Option<Duration>,
+    ) -> Option<(Vec<u64>, Vec<Request>)> {
+        // Iterate for max_size over the BTreeMap starting from next_batch_start_id
+        let mut requests = Vec::new();
+        let mut ids = Vec::new();
+
+        for (id, entry) in self
+            .entries
+            // Start from next_batch_start_id
+            .range(self.next_batch_start_id..)
+            // Take max_size
+            .take(max_size)
+        {
+            if let Some(min_waiting_time) = min_waiting_time {
+                // Only take entries that waited for at least min_waiting_time
+                if entry.time.elapsed() < min_waiting_time {
+                    // Since entries are ordered, we already know that all following entries won't
+                    // satisfy the condition
+                    break;
+                }
+            }
+
+            requests.push(Request {
+                id: *id,
+                inputs: entry.request.inputs.clone(),
+                input_length: entry.input_length as u32,
+                parameters: Some(LogitsWarperParameters::from(
+                    entry.request.parameters.clone(),
+                )),
+                max_new_tokens: entry.request.parameters.max_new_tokens,
+            });
+
+            ids.push(*id);
+        }
+
+        if requests.is_empty() {
+            None
+        } else {
+            Some((ids, requests))
+        }
+    }
+}
+
+impl Db {
+    pub(crate) fn new() -> Self {
+        // Shared state
+        let shared = Arc::new(Shared {
+            state: Mutex::new(State {
+                entries: BTreeMap::new(),
+                next_id: 0,
+                next_batch_id: 0,
+                next_batch_start_id: 0,
+            }),
+        });
+
+        Self { shared }
+    }
+
+    /// Append an entry to the database
+    pub(crate) fn append(&self, entry: Entry) {
+        // Acquire lock
+        let mut state = self.shared.state.lock();
+
+        // Insert entry
+        let id = state.next_id;
+        state.next_id += 1;
+        state.entries.insert(id, entry);
+    }
+
+    /// Remove an entry from the database if it exists
+    pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
+        let mut state = self.shared.state.lock();
+        state.entries.remove(id)
+    }
+
+    // Get the next batch
+    pub(crate) fn next_batch(
+        &self,
+        min_size: Option<usize>,
+        max_size: usize,
+        min_waiting_time: Option<Duration>,
+    ) -> Option<(Vec<u64>, Batch)> {
+        // Acquire lock
+        let mut state = self.shared.state.lock();
+
+        // Get requests from the database
+        if let Some((ids, requests)) = state.next_requests(max_size, min_waiting_time) {
+            if let Some(min_size) = min_size {
+                // If min_size is set, only return a batch if there are enough requests
+                if requests.len() < min_size {
+                    return None;
+                }
+            }
+
+            // Batch size
+            let size = requests.len();
+            // Longest input length of all requests in the batch
+            // Used for padding inside the inference server
+            let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap();
+            let batch = Batch {
+                id: state.next_batch_id,
+                requests,
+                size: size as u32,
+                max_sequence_length,
+            };
+            // Update next_batch_start_id to the last id in the batch + 1
+            state.next_batch_start_id = ids.last().unwrap() + 1;
+            // Increment batch id
+            state.next_batch_id += 1;
+
+            return Some((ids, batch));
+        }
+        None
+    }
 }

 impl From<GenerateParameters> for LogitsWarperParameters {
@@ -23,129 +180,3 @@ impl From<GenerateParameters> for LogitsWarperParameters {
         }
     }
 }
-
-#[derive(Debug, Clone)]
-pub(crate) struct Db {
-    pub shared: Arc<Shared>,
-}
-
-#[derive(Debug)]
-pub struct Shared {
-    state: RwLock<State>,
-}
-
-#[derive(Debug)]
-struct State {
-    entries: BTreeMap<u64, Entry>,
-
-    /// Identifier to use for the next expiration. Each expiration is associated
-    /// with a unique identifier. See above for why.
-    next_id: u64,
-
-    next_batch_id: u64,
-
-    /// Current batch id
-    next_batch_start_id: u64,
-}
-
-impl Db {
-    pub(crate) fn new() -> Self {
-        let shared = Arc::new(Shared {
-            state: RwLock::new(State {
-                entries: BTreeMap::new(),
-                next_id: 0,
-                next_batch_id: 0,
-                next_batch_start_id: 0,
-            }),
-        });
-
-        Self { shared }
-    }
-
-    pub(crate) fn append(&self, entry: Entry) {
-        let mut state = self.shared.state.write();
-
-        let id = state.next_id;
-        state.next_id += 1;
-
-        state.entries.insert(id, entry);
-    }
-
-    pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
-        let mut state = self.shared.state.write();
-        state.entries.remove(id)
-    }
-
-    pub(crate) fn len(&self) -> usize {
-        let state = self.shared.state.read();
-        state.entries.len()
-    }
-
-    fn next_requests(&self, max_size: usize) -> Option<(u64, Vec<Request>)> {
-        let state = self.shared.state.read();
-
-        let requests: Vec<Request> = state
-            .entries
-            .range(state.next_batch_start_id..)
-            .take(max_size)
-            .map(|(id, entry)| Request {
-                id: *id,
-                inputs: entry.request.inputs.clone(),
-                input_length: entry.input_length as u32,
-                parameters: Some(LogitsWarperParameters::from(
-                    entry.request.parameters.clone(),
-                )),
-                max_new_tokens: entry.request.parameters.max_new_tokens,
-            })
-            .collect();
-
-        if requests.is_empty() {
-            None
-        } else {
-            let last_id = requests.last().unwrap().id;
-            Some((last_id, requests))
-        }
-    }
-
-    pub(crate) fn next_batch(&self, max_size: usize) -> Option<Batch> {
-        if let Some((last_id, requests)) = self.next_requests(max_size) {
-            let mut state = self.shared.state.write();
-            let size = requests.len();
-            let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap();
-            let batch = Batch {
-                id: state.next_batch_id,
-                requests,
-                size: size as u32,
-                max_sequence_length,
-            };
-            state.next_batch_start_id = last_id + 1;
-            state.next_batch_id += 1;
-            return Some(batch);
-        }
-        None
-    }
-
-    pub(crate) fn next_batch_minimum_size(
-        &self,
-        min_size: usize,
-        max_size: usize,
-    ) -> Option<Batch> {
-        if let Some((last_id, requests)) = self.next_requests(max_size) {
-            if requests.len() >= min_size {
-                let mut state = self.shared.state.write();
-                let size = requests.len();
-                let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap();
-                let batch = Batch {
-                    id: state.next_batch_id,
-                    requests,
-                    size: size as u32,
-                    max_sequence_length,
-                };
-                state.next_batch_start_id = last_id + 1;
-                state.next_batch_id += 1;
-                return Some(batch);
-            }
-        }
-        None
-    }
-}
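Illustrative sketch (not part of this commit): the selection logic behind State::next_requests, on plain data. Entries are scanned in id order starting from a cursor, at most max_size are taken, and an optional minimum waiting time cuts the scan short as soon as one entry is too recent. The function and variable names here are invented for the example.

use std::collections::BTreeMap;
use std::time::{Duration, Instant};

fn next_ids(
    entries: &BTreeMap<u64, Instant>,
    cursor: u64,
    max_size: usize,
    min_waiting_time: Option<Duration>,
) -> Vec<u64> {
    let mut ids = Vec::new();
    for (id, created_at) in entries.range(cursor..).take(max_size) {
        if let Some(min_waiting_time) = min_waiting_time {
            // Entries are ordered by insertion, so the first "too recent" entry
            // means every later one is too recent as well
            if created_at.elapsed() < min_waiting_time {
                break;
            }
        }
        ids.push(*id);
    }
    ids
}

fn main() {
    let now = Instant::now();
    let entries: BTreeMap<u64, Instant> = (0..5u64).map(|id| (id, now)).collect();
    // Take at most 3 entries starting from id 1, with no waiting-time constraint
    assert_eq!(next_ids(&entries, 1, 3, None), vec![1, 2, 3]);
    // A large minimum waiting time filters everything out
    assert!(next_ids(&entries, 0, 5, Some(Duration::from_secs(3600))).is_empty());
    println!("ok");
}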
@@ -1,8 +1,68 @@
+/// Text Generation Inference Webserver
 mod batcher;
 mod db;
-mod validation;
 pub mod server;
+mod validation;

-use db::{Db, Entry};
 use batcher::Batcher;
+use db::{Db, Entry};
+use serde::{Deserialize, Serialize};
 use validation::Validation;
+
+#[derive(Clone, Debug, Deserialize)]
+pub(crate) struct GenerateParameters {
+    #[serde(default = "default_temperature")]
+    pub temperature: f32,
+    #[serde(default = "default_top_k")]
+    pub top_k: i32,
+    #[serde(default = "default_top_p")]
+    pub top_p: f32,
+    #[serde(default = "default_do_sample")]
+    pub do_sample: bool,
+    #[serde(default = "default_max_new_tokens")]
+    pub max_new_tokens: u32,
+}
+
+fn default_temperature() -> f32 {
+    1.0
+}
+
+fn default_top_k() -> i32 {
+    0
+}
+
+fn default_top_p() -> f32 {
+    1.0
+}
+
+fn default_do_sample() -> bool {
+    false
+}
+
+fn default_max_new_tokens() -> u32 {
+    20
+}
+
+fn default_parameters() -> GenerateParameters {
+    GenerateParameters {
+        temperature: default_temperature(),
+        top_k: default_top_k(),
+        top_p: default_top_p(),
+        do_sample: default_do_sample(),
+        max_new_tokens: default_max_new_tokens(),
+    }
+}
+
+#[derive(Clone, Debug, Deserialize)]
+pub(crate) struct GenerateRequest {
+    pub inputs: String,
+    #[serde(default = "default_parameters")]
+    pub parameters: GenerateParameters,
+}
+
+#[derive(Serialize)]
+pub(crate) struct GeneratedText {
+    pub generated_text: String,
+}
+
+pub(crate) type GenerateResponse = Vec<GeneratedText>;
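Illustrative sketch (not part of this commit): how the serde defaults above play out for a minimal payload. The structs are re-declared locally (with only two of the fields) so the snippet is self-contained; field names and default values mirror the ones defined above, and serde/serde_json are assumed as dependencies.

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct GenerateParameters {
    #[serde(default = "default_temperature")]
    temperature: f32,
    #[serde(default = "default_max_new_tokens")]
    max_new_tokens: u32,
}

fn default_temperature() -> f32 {
    1.0
}

fn default_max_new_tokens() -> u32 {
    20
}

fn default_parameters() -> GenerateParameters {
    GenerateParameters {
        temperature: default_temperature(),
        max_new_tokens: default_max_new_tokens(),
    }
}

#[derive(Debug, Deserialize)]
struct GenerateRequest {
    inputs: String,
    #[serde(default = "default_parameters")]
    parameters: GenerateParameters,
}

fn main() {
    // Only `inputs` is provided: every parameter falls back to its default
    let req: GenerateRequest = serde_json::from_str(r#"{"inputs": "Hello"}"#).unwrap();
    assert_eq!(req.parameters.temperature, 1.0);
    assert_eq!(req.parameters.max_new_tokens, 20);
    println!("{req:?}");
}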
@@ -1,37 +1,61 @@
+/// Text Generation Inference webserver entrypoint
 use bloom_inference_client::ShardedClient;
+use clap::Parser;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+use std::time::Duration;
 use text_generation_router::server;
 use tokenizers::Tokenizer;
-use clap::Parser;

 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
 struct Args {
-    #[clap(default_value = "32", long, short, env)]
+    #[clap(default_value = "128", long, env)]
+    max_concurrent_requests: usize,
+    #[clap(default_value = "1000", long, env)]
+    max_input_length: usize,
+    #[clap(default_value = "32", long, env)]
     max_batch_size: usize,
+    #[clap(default_value = "5", long, env)]
+    max_waiting_time: u64,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
     #[clap(default_value = "/tmp/bloom-inference-0", long, env)]
-    shard_uds_path: String,
+    master_shard_uds_path: String,
     #[clap(default_value = "bigscience/bloom", long, env)]
     tokenizer_name: String,
+    #[clap(default_value = "2", long, env)]
+    validation_workers: usize,
 }

 fn main() -> Result<(), std::io::Error> {
     // Get args
     let args = Args::parse();
     // Pattern match configuration
     let Args {
+        max_concurrent_requests,
+        max_input_length,
         max_batch_size,
+        max_waiting_time,
         port,
-        shard_uds_path,
+        master_shard_uds_path,
         tokenizer_name,
+        validation_workers,
     } = args;

+    if validation_workers == 1 {
+        panic!("validation_workers must be > 0");
+    }
+
+    let max_waiting_time = Duration::from_secs(max_waiting_time);
+
+    // Download and instantiate tokenizer
+    // This will only be used to validate payloads
+    //
+    // We need to download it outside of the Tokio runtime
     let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();

+    // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
         .enable_all()
         .build()
@@ -39,18 +63,32 @@ fn main() -> Result<(), std::io::Error> {
         .block_on(async {
             tracing_subscriber::fmt::init();

-            let sharded_client = ShardedClient::connect_uds(shard_uds_path)
+            // Instantiate sharded client from the master unix socket
+            let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
                 .expect("Could not connect to server");
+            // Clear the cache; useful if the webserver rebooted
             sharded_client
                 .clear_cache()
                 .await
                 .expect("Unable to clear cache");
             tracing::info!("Connected");

+            // Binds on localhost
             let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);

-            server::run(max_batch_size, sharded_client, tokenizer, addr).await;
+            // Run server
+            server::run(
+                max_concurrent_requests,
+                max_input_length,
+                max_batch_size,
+                max_waiting_time,
+                sharded_client,
+                tokenizer,
+                validation_workers,
+                addr,
+            )
+            .await;
             Ok(())
         })
 }
|
||||||
use crate::{Batcher, Validation};
|
use crate::{
|
||||||
|
Batcher, GenerateParameters, GenerateRequest, GenerateResponse, GeneratedText, Validation,
|
||||||
|
};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::StatusCode;
|
use axum::http::StatusCode;
|
||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use axum::{Json, Router};
|
use axum::{Json, Router};
|
||||||
use bloom_inference_client::ShardedClient;
|
use bloom_inference_client::ShardedClient;
|
||||||
use serde::Deserialize;
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
use tokio::signal;
|
||||||
|
use tokio::sync::Semaphore;
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize)]
|
// Server shared state
|
||||||
pub(crate) struct GenerateParameters {
|
#[derive(Clone)]
|
||||||
#[serde(default = "default_temperature")]
|
struct ServerState {
|
||||||
pub temperature: f32,
|
validation: Validation,
|
||||||
#[serde(default = "default_top_k")]
|
batcher: Batcher,
|
||||||
pub top_k: i32,
|
limit_concurrent_requests: Arc<Semaphore>,
|
||||||
#[serde(default = "default_top_p")]
|
|
||||||
pub top_p: f32,
|
|
||||||
#[serde(default = "default_do_sample")]
|
|
||||||
pub do_sample: bool,
|
|
||||||
#[serde(default = "default_max_new_tokens")]
|
|
||||||
pub max_new_tokens: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_temperature() -> f32 {
|
|
||||||
1.0
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_top_k() -> i32 {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_top_p() -> f32 {
|
|
||||||
1.0
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_do_sample() -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_max_new_tokens() -> u32 {
|
|
||||||
20
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_parameters() -> GenerateParameters {
|
|
||||||
GenerateParameters {
|
|
||||||
temperature: default_temperature(),
|
|
||||||
top_k: default_top_k(),
|
|
||||||
top_p: default_top_p(),
|
|
||||||
do_sample: default_do_sample(),
|
|
||||||
max_new_tokens: default_max_new_tokens(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize)]
|
|
||||||
pub(crate) struct GenerateRequest {
|
|
||||||
pub inputs: String,
|
|
||||||
#[serde(default = "default_parameters")]
|
|
||||||
pub parameters: GenerateParameters,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Health check method
|
||||||
#[instrument(skip(state), fields(time, time_per_token))]
|
#[instrument(skip(state), fields(time, time_per_token))]
|
||||||
async fn liveness(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
|
async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
|
||||||
|
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
|
||||||
|
// be a bit too slow for a health check.
|
||||||
|
// What we should do instead if check if the gRPC channels are still healthy.
|
||||||
|
|
||||||
|
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||||
|
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
|
||||||
|
(
|
||||||
|
StatusCode::TOO_MANY_REQUESTS,
|
||||||
|
"Model is overloaded".to_string(),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Send a small inference request
|
||||||
state
|
state
|
||||||
.batcher
|
.batcher
|
||||||
.infer(
|
.infer(
|
||||||
|
@ -82,23 +58,35 @@ async fn liveness(state: Extension<ServerState>) -> Result<(), (StatusCode, Stri
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generate method
|
||||||
#[instrument(skip(state), fields(time, time_per_token))]
|
#[instrument(skip(state), fields(time, time_per_token))]
|
||||||
async fn generate(
|
async fn generate(
|
||||||
state: Extension<ServerState>,
|
state: Extension<ServerState>,
|
||||||
req: Json<GenerateRequest>,
|
req: Json<GenerateRequest>,
|
||||||
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
|
) -> Result<Json<GenerateResponse>, (StatusCode, String)> {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||||
|
let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
|
||||||
|
(
|
||||||
|
StatusCode::TOO_MANY_REQUESTS,
|
||||||
|
"Model is overloaded".to_string(),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Validate request
|
||||||
let (input_length, validated_request) = state
|
let (input_length, validated_request) = state
|
||||||
.validation
|
.validation
|
||||||
|
// FIXME: can't we get rid of the cloning here??
|
||||||
.validate(GenerateRequest {
|
.validate(GenerateRequest {
|
||||||
inputs: req.inputs.clone(),
|
inputs: req.inputs.clone(),
|
||||||
parameters: req.parameters.clone(),
|
parameters: req.parameters.clone(),
|
||||||
})
|
})
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
// Inference
|
||||||
let generated_text = state.batcher.infer(input_length, validated_request).await?;
|
let generated_text = state.batcher.infer(input_length, validated_request).await?;
|
||||||
|
|
||||||
|
// Tracing metadata
|
||||||
tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
|
tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
|
||||||
tracing::Span::current().record(
|
tracing::Span::current().record(
|
||||||
"time_per_token",
|
"time_per_token",
|
||||||
|
@ -106,31 +94,71 @@ async fn generate(
|
||||||
);
|
);
|
||||||
tracing::info!("response: {}", generated_text);
|
tracing::info!("response: {}", generated_text);
|
||||||
|
|
||||||
Ok(Json(serde_json::json!({
|
// Send response
|
||||||
"generated_text": generated_text,
|
let response = vec![GeneratedText { generated_text }];
|
||||||
})))
|
Ok(Json(response))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
/// Serving method
|
||||||
struct ServerState {
|
#[allow(clippy::too_many_arguments)]
|
||||||
validation: Validation,
|
pub async fn run(
|
||||||
batcher: Batcher,
|
max_concurrent_requests: usize,
|
||||||
}
|
max_input_length: usize,
|
||||||
|
max_batch_size: usize,
|
||||||
pub async fn run(max_batch_size: usize, client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) {
|
max_waiting_time: Duration,
|
||||||
let batcher = Batcher::new(client, max_batch_size);
|
client: ShardedClient,
|
||||||
let validation = Validation::new(tokenizer);
|
tokenizer: Tokenizer,
|
||||||
|
validation_workers: usize,
|
||||||
let shared_state = ServerState { validation, batcher };
|
addr: SocketAddr,
|
||||||
|
) {
|
||||||
|
// Create state
|
||||||
|
let batcher = Batcher::new(client, max_batch_size, max_waiting_time);
|
||||||
|
let validation = Validation::new(validation_workers, tokenizer, max_input_length);
|
||||||
|
let shared_state = ServerState {
|
||||||
|
validation,
|
||||||
|
batcher,
|
||||||
|
limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create router
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/generate", post(generate))
|
.route("/generate", post(generate))
|
||||||
.layer(Extension(shared_state.clone()))
|
.layer(Extension(shared_state.clone()))
|
||||||
.route("/health", get(liveness))
|
.route("/health", get(health))
|
||||||
.layer(Extension(shared_state.clone()));
|
.layer(Extension(shared_state.clone()));
|
||||||
|
|
||||||
|
// Run server
|
||||||
axum::Server::bind(&addr)
|
axum::Server::bind(&addr)
|
||||||
.serve(app.into_make_service())
|
.serve(app.into_make_service())
|
||||||
|
// Wait until all requests are finished to shut down
|
||||||
|
.with_graceful_shutdown(shutdown_signal())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Shutdown signal handler
|
||||||
|
async fn shutdown_signal() {
|
||||||
|
let ctrl_c = async {
|
||||||
|
signal::ctrl_c()
|
||||||
|
.await
|
||||||
|
.expect("failed to install Ctrl+C handler");
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
let terminate = async {
|
||||||
|
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||||
|
.expect("failed to install signal handler")
|
||||||
|
.recv()
|
||||||
|
.await;
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(not(unix))]
|
||||||
|
let terminate = std::future::pending::<()>();
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
_ = ctrl_c => {},
|
||||||
|
_ = terminate => {},
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!("signal received, starting graceful shutdown");
|
||||||
|
}
|
||||||
|
|
|
@ -1,62 +1,105 @@
|
||||||
use crate::server::GenerateRequest;
|
/// Payload validation logic
|
||||||
|
use crate::GenerateRequest;
|
||||||
use axum::http::StatusCode;
|
use axum::http::StatusCode;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
|
use tokenizers::{
|
||||||
|
DecoderWrapper, ModelWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper,
|
||||||
|
TokenizerImpl,
|
||||||
|
};
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
/// Validation
|
||||||
pub enum ValidationError {
|
|
||||||
#[error("Temperature must be strictly positive")]
|
|
||||||
Temperature,
|
|
||||||
#[error("Top p must be <= 0.0 or > 1.0")]
|
|
||||||
TopP,
|
|
||||||
#[error("Top k must be strictly positive")]
|
|
||||||
TopK,
|
|
||||||
#[error("Max New Tokens must be < 512")]
|
|
||||||
MaxNewTokens,
|
|
||||||
#[error("Inputs must have less than 1000 tokens. Given: {0}")]
|
|
||||||
InputLength(usize),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<ValidationError> for (StatusCode, String) {
|
|
||||||
fn from(err: ValidationError) -> Self {
|
|
||||||
(StatusCode::BAD_REQUEST, err.to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ValidationRequest = (
|
|
||||||
GenerateRequest,
|
|
||||||
oneshot::Sender<Result<(usize, GenerateRequest), ValidationError>>,
|
|
||||||
);
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Validation {
|
pub struct Validation {
|
||||||
|
/// Channel to communicate with the background validation task
|
||||||
sender: mpsc::Sender<ValidationRequest>,
|
sender: mpsc::Sender<ValidationRequest>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Validation {
|
impl Validation {
|
||||||
pub(crate) fn new(tokenizer: Tokenizer) -> Self {
|
pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self {
|
||||||
|
// Crate channel
|
||||||
let (validation_sender, validation_receiver) = mpsc::channel(128);
|
let (validation_sender, validation_receiver) = mpsc::channel(128);
|
||||||
|
|
||||||
tokio::spawn(validation_task(tokenizer, validation_receiver));
|
// Launch background validation task
|
||||||
|
tokio::spawn(validation_task(
|
||||||
|
workers,
|
||||||
|
tokenizer,
|
||||||
|
max_input_length,
|
||||||
|
validation_receiver,
|
||||||
|
));
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
sender: validation_sender,
|
sender: validation_sender,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Validate a payload and get the number of tokens in the input
|
||||||
pub(crate) async fn validate(
|
pub(crate) async fn validate(
|
||||||
&self,
|
&self,
|
||||||
request: GenerateRequest,
|
request: GenerateRequest,
|
||||||
) -> Result<(usize, GenerateRequest), ValidationError> {
|
) -> Result<(usize, GenerateRequest), ValidationError> {
|
||||||
|
// Create response channel
|
||||||
let (sender, receiver) = oneshot::channel();
|
let (sender, receiver) = oneshot::channel();
|
||||||
|
// Send request to the background validation task
|
||||||
|
// Unwrap is safe here
|
||||||
self.sender.send((request, sender)).await.unwrap();
|
self.sender.send((request, sender)).await.unwrap();
|
||||||
|
// Await on response channel
|
||||||
|
// Unwrap is safe here
|
||||||
receiver.await.unwrap()
|
receiver.await.unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<ValidationRequest>) {
|
/// Validation task
|
||||||
while let Some((request, response_tx)) = receiver.recv().await {
|
/// Load balance the validation requests between multiple validation workers
|
||||||
|
async fn validation_task(
|
||||||
|
workers: usize,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
max_input_length: usize,
|
||||||
|
mut receiver: mpsc::Receiver<ValidationRequest>,
|
||||||
|
) {
|
||||||
|
let mut workers_senders = Vec::with_capacity(workers);
|
||||||
|
|
||||||
|
// Create workers
|
||||||
|
for _ in 0..workers {
|
||||||
|
let tokenizer_clone = tokenizer.clone();
|
||||||
|
// Create channel to communicate with worker
|
||||||
|
let (worker_sender, worker_receiver) = mpsc::channel(workers);
|
||||||
|
workers_senders.push(worker_sender);
|
||||||
|
|
||||||
|
// Spawn worker
|
||||||
|
tokio::task::spawn_blocking(move || {
|
||||||
|
validation_worker(tokenizer_clone, max_input_length, worker_receiver)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
loop {
|
||||||
|
// Load balance requests between workers
|
||||||
|
for sender in workers_senders.iter() {
|
||||||
|
if let Some(validation_request) = receiver.recv().await {
|
||||||
|
sender.send(validation_request).await.unwrap();
|
||||||
|
} else {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check the parameters inside the payload and get the number of tokens inside the input using
|
||||||
|
/// the tokenizer
|
||||||
|
fn validation_worker(
|
||||||
|
tokenizer: TokenizerImpl<
|
||||||
|
ModelWrapper,
|
||||||
|
NormalizerWrapper,
|
||||||
|
PreTokenizerWrapper,
|
||||||
|
PostProcessorWrapper,
|
||||||
|
DecoderWrapper,
|
||||||
|
>,
|
||||||
|
max_input_length: usize,
|
||||||
|
mut receiver: mpsc::Receiver<ValidationRequest>,
|
||||||
|
) {
|
||||||
|
// Loop over requests
|
||||||
|
while let Some((request, response_tx)) = receiver.blocking_recv() {
|
||||||
if request.parameters.temperature < 0.0 {
|
if request.parameters.temperature < 0.0 {
|
||||||
response_tx
|
response_tx
|
||||||
.send(Err(ValidationError::Temperature))
|
.send(Err(ValidationError::Temperature))
|
||||||
|
@ -78,10 +121,11 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<Vali
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the number of tokens in the input
|
||||||
let inputs = tokenizer.encode(request.inputs.clone(), false).unwrap();
|
let inputs = tokenizer.encode(request.inputs.clone(), false).unwrap();
|
||||||
let input_length = inputs.len();
|
let input_length = inputs.len();
|
||||||
|
|
||||||
if input_length > 1000 {
|
if input_length > max_input_length {
|
||||||
response_tx
|
response_tx
|
||||||
.send(Err(ValidationError::InputLength(input_length)))
|
.send(Err(ValidationError::InputLength(input_length)))
|
||||||
.unwrap_or(());
|
.unwrap_or(());
|
||||||
|
@ -91,3 +135,28 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<Vali
|
||||||
response_tx.send(Ok((input_length, request))).unwrap_or(());
|
response_tx.send(Ok((input_length, request))).unwrap_or(());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ValidationRequest = (
|
||||||
|
GenerateRequest,
|
||||||
|
oneshot::Sender<Result<(usize, GenerateRequest), ValidationError>>,
|
||||||
|
);
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum ValidationError {
|
||||||
|
#[error("Temperature must be strictly positive")]
|
||||||
|
Temperature,
|
||||||
|
#[error("Top p must be <= 0.0 or > 1.0")]
|
||||||
|
TopP,
|
||||||
|
#[error("Top k must be strictly positive")]
|
||||||
|
TopK,
|
||||||
|
#[error("Max New Tokens must be < 512")]
|
||||||
|
MaxNewTokens,
|
||||||
|
#[error("Inputs must have less than 1000 tokens. Given: {0}")]
|
||||||
|
InputLength(usize),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ValidationError> for (StatusCode, String) {
|
||||||
|
fn from(err: ValidationError) -> Self {
|
||||||
|
(StatusCode::BAD_REQUEST, err.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
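Illustrative sketch (not part of this commit): the worker pattern used above, with the tokenizer replaced by a plain uppercase transform. CPU-bound work runs on blocking threads; an async task round-robins jobs between the workers. The Job type and worker function are invented for the example.

use tokio::sync::{mpsc, oneshot};

type Job = (String, oneshot::Sender<String>);

fn worker(mut rx: mpsc::Receiver<Job>) {
    // blocking_recv is the right receive call outside of an async context
    while let Some((input, reply_tx)) = rx.blocking_recv() {
        reply_tx.send(input.to_uppercase()).unwrap_or(());
    }
}

#[tokio::main]
async fn main() {
    let workers = 2;
    let mut senders = Vec::with_capacity(workers);
    for _ in 0..workers {
        let (tx, rx) = mpsc::channel::<Job>(8);
        senders.push(tx);
        tokio::task::spawn_blocking(move || worker(rx));
    }

    // Round-robin two jobs over the two workers and await both answers
    for (i, text) in ["hello", "world"].into_iter().enumerate() {
        let (reply_tx, reply_rx) = oneshot::channel();
        senders[i % workers].send((text.to_string(), reply_tx)).await.unwrap();
        println!("{}", reply_rx.await.unwrap());
    }
}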
run.sh | 30
@@ -1,30 +0,0 @@
-#!/usr/bin/env bash
-
-server_cmd="bloom-inference-server launcher $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"
-
-# Run in background
-$server_cmd 2>&1 > /dev/null &
-
-# Check if server is running by checking if the unix socket is created
-FILE=/tmp/bloom-inference-0
-while :
-do
-    if test -S "$FILE"; then
-        echo "Text Generation Python gRPC server started"
-        break
-    else
-        echo "Waiting for Text Generation Python gRPC server to start"
-        sleep 5
-    fi
-done
-
-sleep 1
-
-# Run in background
-text-generation-router &
-
-# Wait for any process to exit
-wait -n
-
-# Exit with status of process that exited first
-exit $?
@ -0,0 +1,155 @@
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
bloom_inference/__pycache__/
bloom_inference/pb/__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
@@ -4,17 +4,28 @@ gen-server:
 	find bloom_inference/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
 	touch bloom_inference/pb/__init__.py
 
-unit-tests:
-	python -m pytest --cov=bloom_inference tests
+install-transformers:
+	# Install specific version of transformers
+	rm transformers || true
+	wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
+	unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
+	rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
+	mv transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 transformers
+	cd transformers && python setup.py install
 
-unit-tests-reporting:
-	python -m pytest --junitxml=report.xml --cov=bloom_inference tests
+install-torch:
+	# Install specific version of torch
+	pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir
 
 pip-install:
 	pip install grpcio-tools
 	make gen-server
+	make install-torch
+	make install-transformers
 	pip install .
 
 install:
 	poetry install
 	make gen-server
+	make install-torch
+	make install-transformers
@@ -1,41 +1,51 @@
+import os
 import typer
 
 from pathlib import Path
-from torch.distributed.launcher import launch_agent, LaunchConfig
 from typing import Optional
 
-from bloom_inference import server
+from bloom_inference import prepare_weights, server
 
 app = typer.Typer()
 
 
 @app.command()
-def launcher(
+def serve(
     model_name: str,
-    num_gpus: int = 1,
+    sharded: bool = False,
     shard_directory: Optional[Path] = None,
+    uds_path: Path = "/tmp/bloom-inference",
 ):
-    if num_gpus == 1:
-        serve(model_name, False, shard_directory)
+    if sharded:
+        assert (
+            shard_directory is not None
+        ), "shard_directory must be set when sharded is True"
+        assert (
+            os.getenv("RANK", None) is not None
+        ), "RANK must be set when sharded is True"
+        assert (
+            os.getenv("WORLD_SIZE", None) is not None
+        ), "WORLD_SIZE must be set when sharded is True"
+        assert (
+            os.getenv("MASTER_ADDR", None) is not None
+        ), "MASTER_ADDR must be set when sharded is True"
+        assert (
+            os.getenv("MASTER_PORT", None) is not None
+        ), "MASTER_PORT must be set when sharded is True"
 
-    else:
-        config = LaunchConfig(
-            min_nodes=1,
-            max_nodes=1,
-            nproc_per_node=num_gpus,
-            rdzv_backend="c10d",
-            max_restarts=0,
-        )
-        launch_agent(config, server.serve, [model_name, True, shard_directory])
+    server.serve(model_name, sharded, uds_path, shard_directory)
 
 
 @app.command()
-def serve(
+def prepare_weights(
     model_name: str,
-    sharded: bool = False,
-    shard_directory: Optional[Path] = None,
+    shard_directory: Path,
+    cache_directory: Path,
+    num_shard: int = 1,
 ):
-    server.serve(model_name, sharded, shard_directory)
+    prepare_weights.prepare_weights(
+        model_name, cache_directory, shard_directory, num_shard
+    )
 
 
 if __name__ == "__main__":
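With the torch elastic launcher removed, the sharded path of `serve` now relies on the usual torch.distributed environment variables being provided by an external launcher. A minimal sketch of the environment the asserts above check for, using hypothetical values for a single-node, two-shard run (any torchrun-style launcher would normally set these):

import os

# Hypothetical values for illustration only.
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "2")
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")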
@@ -24,6 +24,7 @@ torch.manual_seed(0)
 class Batch:
     batch_id: int
     requests: List[generate_pb2.Request]
+    all_input_lengths: List[int]
     input_ids: Dict[str, torch.Tensor]
     all_input_ids: List[torch.Tensor]
     next_token_choosers: List[NextTokenChooser]
@@ -46,12 +47,12 @@ class Batch:
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-        input_lengths = []
+        all_input_lengths = []
 
         # Parse batch
         for r in pb.requests:
             inputs.append(r.inputs)
-            input_lengths.append(r.input_length)
+            all_input_lengths.append(r.input_length)
             next_token_choosers.append(
                 NextTokenChooser(
                     temperature=r.parameters.temperature,
@@ -63,17 +64,12 @@ class Batch:
             stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))
 
         input_ids = tokenizer(inputs, return_tensors="pt", padding=True).to(device)
-        # Remove padding from all_input_ids
-        all_input_ids = [
-            input_ids.squeeze(0)[-length:].unsqueeze(-1)
-            for length, input_ids in zip(
-                input_lengths, input_ids["input_ids"].split(1, dim=0)
-            )
-        ]
+        all_input_ids = input_ids["input_ids"].unsqueeze(-1)
 
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
+            all_input_lengths=all_input_lengths,
             input_ids=input_ids,
             all_input_ids=all_input_ids,
             next_token_choosers=next_token_choosers,
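`from_pb` no longer strips the left padding from each request's ids; it keeps the padded tensor and records the true per-request lengths in `all_input_lengths` instead. A small illustrative sketch (the values are made up) of the shape produced by `unsqueeze(-1)`:

import torch

# Two requests, left-padded to length 4; true lengths tracked separately.
padded = torch.tensor([[0, 0, 5, 6], [1, 2, 3, 4]])
all_input_lengths = [2, 4]
all_input_ids = padded.unsqueeze(-1)  # shape: [batch, max_seq_len, 1]
assert all_input_ids.shape == (2, 4, 1)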
@@ -91,6 +87,7 @@ class Batch:
         # Batch attributes
         input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
         requests = []
+        all_input_lengths = []
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []
@@ -100,6 +97,7 @@ class Batch:
         start_index = 0
         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
+            all_input_lengths.extend(batch.all_input_lengths)
             all_input_ids.extend(batch.all_input_ids)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
@@ -198,6 +196,7 @@ class Batch:
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
+            all_input_lengths=all_input_lengths,
             input_ids=input_ids,
             all_input_ids=all_input_ids,
             next_token_choosers=next_token_choosers,
@@ -227,7 +226,10 @@ class BLOOM:
 
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
         self.model = (
-            AutoModelForCausalLM.from_pretrained(model_name).eval().to(self.device).to(dtype)
+            AutoModelForCausalLM.from_pretrained(model_name)
+            .eval()
+            .to(self.device)
+            .to(dtype)
         )
         self.num_heads = self.model.base_model.num_heads
 
@@ -253,6 +255,7 @@ class BLOOM:
         # New input_ids for next forward
         next_batch_input_ids = []
         next_batch_all_input_ids = []
+        next_all_input_lengths = []
 
         next_batch_size = 0
         next_batch_max_sequence_length = 0
@@ -263,6 +266,7 @@ class BLOOM:
         # Zipped iterator
         iterator = zip(
             batch.requests,
+            batch.all_input_lengths,
             outputs.logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
@@ -272,6 +276,7 @@ class BLOOM:
         # For each member of the batch
         for i, (
             request,
+            input_length,
             logits,
             next_token_chooser,
             stopping_criteria,
@@ -302,8 +307,10 @@ class BLOOM:
                 next_batch_input_ids.append(next_token)
                 next_batch_all_input_ids.append(all_tokens)
                 next_batch_size += 1
+                new_input_length = input_length + 1
+                next_all_input_lengths.append(new_input_length)
                 next_batch_max_sequence_length = max(
-                    next_batch_max_sequence_length, len(all_tokens)
+                    next_batch_max_sequence_length, new_input_length
                 )
 
         # We finished all generations in the batch; there is no next batch
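The running maximum is now driven by the tracked per-request lengths rather than `len(all_tokens)`, which includes padding. A toy illustration with made-up lengths:

# Each surviving request grows by exactly one token per decoding step.
all_input_lengths = [3, 7]
next_all_input_lengths = [length + 1 for length in all_input_lengths]
next_batch_max_sequence_length = max(next_all_input_lengths)  # 8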
@@ -350,6 +357,7 @@ class BLOOM:
         next_batch = Batch(
             batch_id=batch.batch_id,
             requests=next_batch_requests,
+            all_input_lengths=next_all_input_lengths,
             input_ids=next_batch_input_ids,
             all_input_ids=next_batch_all_input_ids,
             next_token_choosers=next_batch_next_token_choosers,
@@ -378,7 +386,10 @@ class BLOOMSharded(BLOOM):
         if self.master:
             # TODO @thomasw21 do some caching
             shard_state_dict_paths = prepare_weights(
-                model_name, shard_directory / "cache", shard_directory, tp_world_size=self.world_size
+                model_name,
+                shard_directory / "cache",
+                shard_directory,
+                tp_world_size=self.world_size,
             )
             shard_state_dict_paths = [
                 str(path.absolute()) for path in shard_state_dict_paths
@@ -443,6 +454,7 @@ class BLOOMSharded(BLOOM):
             use_cache=True,
         )
 
+        # Logits are sharded, so we need to gather them
         logits_shard = outputs.logits[:, -1, :].contiguous()
 
         batch_size, vocab_shard_size = logits_shard.shape
@@ -1,2 +1,2 @@
 *.py
 *.py-e
@@ -14,15 +14,15 @@ from huggingface_hub.file_download import _request_wrapper, hf_raise_for_status
 
 
 def match_suffix(text, suffix):
-    return text[-len(suffix):] == suffix
+    return text[-len(suffix) :] == suffix
 
 
 def http_get(
     url: str,
     temp_file: BinaryIO,
     *,
     timeout=10.0,
     max_retries=0,
 ):
     """
     Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub.
@@ -54,7 +54,9 @@ def cache_download_url(url: str, root_dir: Path):
     return filename
 
 
-def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world_size: int):
+def prepare_weights(
+    model_name: str, cache_path: Path, save_path: Path, tp_world_size: int
+):
     save_paths = [
         save_path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
         for tp_rank in range(tp_world_size)
@@ -68,6 +70,7 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
     if model_name == "bigscience/bloom-560m":
         url = hf_hub_url(model_name, filename="pytorch_model.bin")
         cache_download_url(url, cache_path)
+
     elif model_name == "bigscience/bloom":
         url = hf_hub_url(model_name, filename="pytorch_model.bin.index.json")
         index_path = cache_download_url(url, cache_path)
@@ -75,10 +78,14 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
             index = json.load(f)
 
         # Get unique file names
-        weight_files = list(set([filename for filename in index["weight_map"].values()]))
+        weight_files = list(
+            set([filename for filename in index["weight_map"].values()])
+        )
         urls = [hf_hub_url(model_name, filename=filename) for filename in weight_files]
 
-        Parallel(n_jobs=5)(delayed(cache_download_url)(url, cache_path) for url in tqdm(urls))
+        Parallel(n_jobs=5)(
+            delayed(cache_download_url)(url, cache_path) for url in tqdm(urls)
+        )
     else:
         raise ValueError(f"Unknown model name: {model_name}")
@@ -91,14 +98,14 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
     for state_name in keys:
         state = state_dict[state_name]
         if any(
             match_suffix(state_name, candidate)
             for candidate in [
                 "self_attention.query_key_value.weight",
                 "self_attention.query_key_value.bias",
                 "mlp.dense_h_to_4h.weight",
                 "mlp.dense_h_to_4h.bias",
                 "word_embeddings.weight",
             ]
         ):
             output_size = state.shape[0]
             assert output_size % tp_world_size == 0
|
||||||
assert len(sharded_weights) == tp_world_size
|
assert len(sharded_weights) == tp_world_size
|
||||||
|
|
||||||
for tp_rank, shard in enumerate(sharded_weights):
|
for tp_rank, shard in enumerate(sharded_weights):
|
||||||
shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone()
|
shards_state_dicts[tp_rank][
|
||||||
|
"transformer." + state_name
|
||||||
|
] = shard.detach().clone()
|
||||||
|
|
||||||
elif match_suffix(state_name, "lm_head.weight"):
|
elif match_suffix(state_name, "lm_head.weight"):
|
||||||
output_size = state.shape[0]
|
output_size = state.shape[0]
|
||||||
|
@ -120,11 +129,11 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
|
||||||
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
|
shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
|
||||||
|
|
||||||
elif any(
|
elif any(
|
||||||
match_suffix(state_name, candidate)
|
match_suffix(state_name, candidate)
|
||||||
for candidate in [
|
for candidate in [
|
||||||
"self_attention.dense.weight",
|
"self_attention.dense.weight",
|
||||||
"mlp.dense_4h_to_h.weight",
|
"mlp.dense_4h_to_h.weight",
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
input_size = state.shape[1]
|
input_size = state.shape[1]
|
||||||
assert input_size % tp_world_size == 0
|
assert input_size % tp_world_size == 0
|
||||||
|
@ -132,23 +141,31 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
|
||||||
sharded_weights = torch.split(state, block_size, dim=1)
|
sharded_weights = torch.split(state, block_size, dim=1)
|
||||||
assert len(sharded_weights) == tp_world_size
|
assert len(sharded_weights) == tp_world_size
|
||||||
for tp_rank, shard in enumerate(sharded_weights):
|
for tp_rank, shard in enumerate(sharded_weights):
|
||||||
shards_state_dicts[tp_rank]["transformer." + state_name] = shard.detach().clone()
|
shards_state_dicts[tp_rank][
|
||||||
|
"transformer." + state_name
|
||||||
|
] = shard.detach().clone()
|
||||||
|
|
||||||
elif any(
|
elif any(
|
||||||
match_suffix(state_name, candidate)
|
match_suffix(state_name, candidate)
|
||||||
for candidate in [
|
for candidate in [
|
||||||
"self_attention.dense.bias",
|
"self_attention.dense.bias",
|
||||||
"mlp.dense_4h_to_h.bias",
|
"mlp.dense_4h_to_h.bias",
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
shards_state_dicts[0]["transformer." + state_name] = state.detach().clone()
|
shards_state_dicts[0][
|
||||||
|
"transformer." + state_name
|
||||||
|
] = state.detach().clone()
|
||||||
for tp_rank in range(1, tp_world_size):
|
for tp_rank in range(1, tp_world_size):
|
||||||
shards_state_dicts[tp_rank]["transformer." + state_name] = torch.zeros_like(state)
|
shards_state_dicts[tp_rank][
|
||||||
|
"transformer." + state_name
|
||||||
|
] = torch.zeros_like(state)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# We duplicate parameters across tp ranks
|
# We duplicate parameters across tp ranks
|
||||||
for tp_rank in range(tp_world_size):
|
for tp_rank in range(tp_world_size):
|
||||||
shards_state_dicts[tp_rank]["transformer." + state_name] = state.detach().clone()
|
shards_state_dicts[tp_rank][
|
||||||
|
"transformer." + state_name
|
||||||
|
] = state.detach().clone()
|
||||||
|
|
||||||
del state_dict[state_name] # delete key from state_dict
|
del state_dict[state_name] # delete key from state_dict
|
||||||
del state # delete tensor
|
del state # delete tensor
|
||||||
|
@ -156,7 +173,7 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
|
||||||
|
|
||||||
# we save state_dict
|
# we save state_dict
|
||||||
for tp_rank, (save_path, shard_state_dict) in enumerate(
|
for tp_rank, (save_path, shard_state_dict) in enumerate(
|
||||||
zip(save_paths, shards_state_dicts)
|
zip(save_paths, shards_state_dicts)
|
||||||
):
|
):
|
||||||
save_paths.append(save_path)
|
save_paths.append(save_path)
|
||||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
@ -166,17 +183,3 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
|
||||||
torch.save(shard_state_dict, save_path)
|
torch.save(shard_state_dict, save_path)
|
||||||
|
|
||||||
return save_paths
|
return save_paths
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
from argparse import ArgumentParser
|
|
||||||
|
|
||||||
parser = ArgumentParser()
|
|
||||||
|
|
||||||
parser.add_argument("--model-name", required=True, type=str)
|
|
||||||
parser.add_argument("--cache-path", required=True, type=str)
|
|
||||||
parser.add_argument("--save-path", required=True, type=str)
|
|
||||||
parser.add_argument("--world-size", required=True, type=int)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
prepare_weights(args.model_name, Path(args.cache_path), Path(args.save_path), args.world_size)
|
|
||||||
|
|
|
@@ -64,70 +64,31 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batch=next_batch.to_pb() if next_batch else None,
         )
 
-    async def GenerateUntilFinished(self, request, context):
-        batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device)
-
-        generated_texts = []
-        while not generated_texts:
-            generated_texts, next_batch = self.model.generate_token(batch)
-            batch = next_batch
-        self.cache.set(next_batch)
-
-        return generate_pb2.GenerateUntilFinishedResponse(
-            generated_texts=[
-                generated_text.to_pb() for generated_text in generated_texts
-            ],
-            batch=next_batch.to_pb() if next_batch else None,
-        )
-
-    async def GenerateUntilFinishedWithCache(self, request, context):
-        if len(request.batches) == 0:
-            raise ValueError("Must provide at least one batch")
-
-        batches = []
-        for batch_pb in request.batches:
-            batch = self.cache.pop(batch_pb.id)
-            if batch is None:
-                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
-            batches.append(batch)
-
-        if len(batches) > 1:
-            batch = Batch.concatenate(batches)
-        else:
-            batch = batches[0]
-
-        generated_texts = []
-        while not generated_texts:
-            generated_texts, next_batch = self.model.generate_token(batch)
-            batch = next_batch
-        self.cache.set(next_batch)
-
-        return generate_pb2.GenerateUntilFinishedWithCacheResponse(
-            generated_texts=[
-                generated_text.to_pb() for generated_text in generated_texts
-            ],
-            batch=next_batch.to_pb() if next_batch else None,
-        )
-
 
-def serve(model_name, sharded, shard_directory):
+def serve(
+    model_name: str,
+    sharded: bool,
+    uds_path: Path,
+    shard_directory: Optional[Path] = None,
+):
     async def serve_inner(
         model_name: str,
         sharded: bool = False,
         shard_directory: Optional[Path] = None,
     ):
-        unix_socket_template = "unix:///tmp/bloom-inference-{}"
+        unix_socket_template = "unix://{}-{}"
         if sharded:
             if shard_directory is None:
                 raise ValueError("shard_directory must be set when sharded is True")
             model = BLOOMSharded(model_name, shard_directory)
             server_urls = [
-                unix_socket_template.format(rank) for rank in range(model.world_size)
+                unix_socket_template.format(uds_path, rank)
+                for rank in range(model.world_size)
             ]
-            local_url = unix_socket_template.format(model.rank)
+            local_url = server_urls[model.rank]
         else:
             model = BLOOM(model_name)
-            local_url = unix_socket_template.format(0)
+            local_url = unix_socket_template.format(uds_path, 0)
             server_urls = [local_url]
 
         server = aio.server()
|
||||||
server.add_insecure_port(local_url)
|
server.add_insecure_port(local_url)
|
||||||
await server.start()
|
await server.start()
|
||||||
print("Server started at {}".format(local_url))
|
print("Server started at {}".format(local_url))
|
||||||
await server.wait_for_termination()
|
try:
|
||||||
|
await server.wait_for_termination()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("Signal received. Shutting down")
|
||||||
|
await server.stop(0)
|
||||||
|
|
||||||
asyncio.run(serve_inner(model_name, sharded, shard_directory))
|
asyncio.run(serve_inner(model_name, sharded, shard_directory))
|
||||||
|
|
|
@@ -82,7 +82,6 @@ def initialize_torch_distributed():
         world_size=world_size,
         rank=rank,
         timeout=timedelta(seconds=60),
-        init_method="tcp://localhost:6000",
     )
 
     return torch.distributed.distributed_c10d._get_default_group(), rank, world_size
@@ -205,7 +205,7 @@ python-versions = ">=3.7"
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.9"
-content-hash = "f3dc5b2420183f2e7e9257e372489409d7bd26d1dcc535fc2558ebca50c988c2"
+content-hash = "a4eef5f52e8d046aa883082c865b0865047f611a3240b18250487d4b6e831496"
 
 [metadata.files]
 accelerate = [
@@ -11,7 +11,6 @@ bloom-inference-server = 'bloom_inference.cli:app'
 python = "^3.9"
 protobuf = "^4.21.7"
 grpcio = "^1.49.1"
-torch = "^1.12.1"
 typer = "^0.6.1"
 grpcio-reflection = "^1.49.1"
 accelerate = "^0.12.0"