feat(server): Support all AutoModelForCausalLM on a best effort basis

parent 09674e6df9
commit 3cf6368c77
@@ -28,9 +28,9 @@ dependencies = [

 [[package]]
 name = "anyhow"
-version = "1.0.65"
+version = "1.0.66"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "98161a4e3e2184da77bb14f02184cdd111e83bbbcc9979dfee3c44b9a85f5602"
+checksum = "216261ddc8289130e551ddcd5ce8a064710c0d064a4d2895c67151c92b5443f6"

 [[package]]
 name = "async-stream"

@@ -83,9 +83,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"

 [[package]]
 name = "axum"
-version = "0.5.16"
+version = "0.5.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c9e3356844c4d6a6d6467b8da2cffb4a2820be256f50a3a386c9d152bab31043"
+checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43"
 dependencies = [
  "async-trait",
  "axum-core",

@@ -114,9 +114,9 @@ dependencies = [

 [[package]]
 name = "axum-core"
-version = "0.2.8"
+version = "0.2.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d9f0c0a60006f2a293d82d571f635042a72edf927539b7685bd62d361963839b"
+checksum = "37e5939e02c56fecd5c017c37df4238c0a839fa76b7f97acdd7efb804fd181cc"
 dependencies = [
  "async-trait",
  "bytes",

@@ -130,9 +130,9 @@ dependencies = [

 [[package]]
 name = "base64"
-version = "0.13.0"
+version = "0.13.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd"
+checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"

 [[package]]
 name = "bitflags"

@@ -149,21 +149,6 @@ dependencies = [
  "generic-array",
 ]

-[[package]]
-name = "bloom-inference-client"
-version = "0.1.0"
-dependencies = [
- "futures",
- "prost",
- "thiserror",
- "tokio",
- "tonic",
- "tonic-build",
- "tower",
- "tracing",
- "tracing-error",
-]
-
 [[package]]
 name = "bumpalo"
 version = "3.11.1"

@@ -255,9 +240,9 @@ dependencies = [

 [[package]]
 name = "clap"
-version = "4.0.17"
+version = "4.0.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "06badb543e734a2d6568e19a40af66ed5364360b9226184926f89d229b4b4267"
+checksum = "335867764ed2de42325fafe6d18b8af74ba97ee0c590fa016f157535b42ab04b"
 dependencies = [
  "atty",
  "bitflags",

@@ -270,9 +255,9 @@ dependencies = [

 [[package]]
 name = "clap_derive"
-version = "4.0.13"
+version = "4.0.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c42f169caba89a7d512b5418b09864543eeb4d497416c917d7137863bd2076ad"
+checksum = "16a1b0f6422af32d5da0c58e2703320f379216ee70198241c84173a8c5ac28f3"
 dependencies = [
  "heck 0.4.0",
  "proc-macro-error",

@@ -532,14 +517,14 @@ dependencies = [

 [[package]]
 name = "filetime"
-version = "0.2.17"
+version = "0.2.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e94a7bbaa59354bc20dd75b67f23e2797b4490e9d6928203fb105c79e448c86c"
+checksum = "4b9663d381d07ae25dc88dbdf27df458faa83a9b25336bcac83d5e452b5fc9d3"
 dependencies = [
  "cfg-if",
  "libc",
  "redox_syscall",
- "windows-sys 0.36.1",
+ "windows-sys 0.42.0",
 ]

 [[package]]

@@ -600,9 +585,9 @@ dependencies = [

 [[package]]
 name = "futures"
-version = "0.3.24"
+version = "0.3.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7f21eda599937fba36daeb58a22e8f5cee2d14c4a17b5b7739c7c8e5e3b8230c"
+checksum = "38390104763dc37a5145a53c29c63c1290b5d316d6086ec32c293f6736051bb0"
 dependencies = [
  "futures-channel",
  "futures-core",

@@ -615,9 +600,9 @@ dependencies = [

 [[package]]
 name = "futures-channel"
-version = "0.3.24"
+version = "0.3.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "30bdd20c28fadd505d0fd6712cdfcb0d4b5648baf45faef7f852afb2399bb050"
+checksum = "52ba265a92256105f45b719605a571ffe2d1f0fea3807304b522c1d778f79eed"
 dependencies = [
  "futures-core",
  "futures-sink",

@@ -625,15 +610,15 @@ dependencies = [

 [[package]]
 name = "futures-core"
-version = "0.3.24"
+version = "0.3.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4e5aa3de05362c3fb88de6531e6296e85cde7739cccad4b9dfeeb7f6ebce56bf"
+checksum = "04909a7a7e4633ae6c4a9ab280aeb86da1236243a77b694a49eacd659a4bd3ac"

 [[package]]
 name = "futures-executor"
-version = "0.3.24"
+version = "0.3.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9ff63c23854bee61b6e9cd331d523909f238fc7636290b96826e9cfa5faa00ab"
+checksum = "7acc85df6714c176ab5edf386123fafe217be88c0840ec11f199441134a074e2"
 dependencies = [
  "futures-core",
  "futures-task",

@@ -642,15 +627,15 @@ dependencies = [

 [[package]]
 name = "futures-io"
-version = "0.3.24"
+version = "0.3.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bbf4d2a7a308fd4578637c0b17c7e1c7ba127b8f6ba00b29f717e9655d85eb68"
+checksum = "00f5fb52a06bdcadeb54e8d3671f8888a39697dcb0b81b23b55174030427f4eb"

 [[package]]
 name = "futures-macro"
-version = "0.3.24"
+version = "0.3.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "42cd15d1c7456c04dbdf7e88bcd69760d74f3a798d6444e16974b505b0e62f17"
+checksum = "bdfb8ce053d86b91919aad980c220b1fb8401a9394410e1c289ed7e66b61835d"
 dependencies = [
  "proc-macro2",
  "quote",

@@ -659,21 +644,21 @@ dependencies = [

 [[package]]
 name = "futures-sink"
-version = "0.3.24"
+version = "0.3.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "21b20ba5a92e727ba30e72834706623d94ac93a725410b6a6b6fbc1b07f7ba56"
+checksum = "39c15cf1a4aa79df40f1bb462fb39676d0ad9e366c2a33b590d7c66f4f81fcf9"

 [[package]]
 name = "futures-task"
-version = "0.3.24"
+version = "0.3.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a6508c467c73851293f390476d4491cf4d227dbabcd4170f3bb6044959b294f1"
+checksum = "2ffb393ac5d9a6eaa9d3fdf37ae2776656b706e200c8e16b1bdb227f5198e6ea"

 [[package]]
 name = "futures-util"
-version = "0.3.24"
+version = "0.3.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "44fb6cb1be61cc1d2e43b262516aafcf63b241cffdb1d3fa115f91d9c7b09c90"
+checksum = "197676987abd2f9cadff84926f410af1c183608d36641465df73ae8211dc65d6"
 dependencies = [
  "futures-channel",
  "futures-core",

@@ -699,9 +684,9 @@ dependencies = [

 [[package]]
 name = "getrandom"
-version = "0.2.7"
+version = "0.2.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4eb1a864a501629691edf6c15a593b7a51eebaa1e8468e9ddc623de7c9b58ec6"
+checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31"
 dependencies = [
  "cfg-if",
  "libc",

@@ -716,9 +701,9 @@ checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"

 [[package]]
 name = "h2"
-version = "0.3.14"
+version = "0.3.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5ca32592cf21ac7ccab1825cd87f6c9b3d9022c44d086172ed0966bec8af30be"
+checksum = "5f9f29bc9dda355256b2916cf526ab02ce0aeaaaf2bad60d65ef3f12f11dd0f4"
 dependencies = [
  "bytes",
  "fnv",

@@ -967,9 +952,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"

 [[package]]
 name = "libc"
-version = "0.2.135"
+version = "0.2.137"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "68783febc7782c6c5cb401fbda4de5a9898be1762314da0bb2c10ced61f18b0c"
+checksum = "fc7fcc620a3bff7cdd7a365be3376c97191aeaccc2a603e600951e452615bf89"

 [[package]]
 name = "lock_api"

@@ -992,9 +977,9 @@ dependencies = [

 [[package]]
 name = "macro_rules_attribute"
-version = "0.1.2"
+version = "0.1.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "258c86475e1616d6f2d8f5227cfaabd3dae1f6d5388b9597df8a199d4497aba7"
+checksum = "cf0c9b980bf4f3a37fd7b1c066941dd1b1d0152ce6ee6e8fe8c49b9f6810d862"
 dependencies = [
  "macro_rules_attribute-proc_macro",
  "paste",

@@ -1002,9 +987,9 @@ dependencies = [

 [[package]]
 name = "macro_rules_attribute-proc_macro"
-version = "0.1.2"
+version = "0.1.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f26a8d2502d5aa4d411ef494ba7470eb299f05725179ce3b5de77aa01a9ffdea"
+checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"

 [[package]]
 name = "matchit"

@@ -1050,14 +1035,14 @@ dependencies = [

 [[package]]
 name = "mio"
-version = "0.8.4"
+version = "0.8.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "57ee1c23c7c63b0c9250c339ffdc69255f110b298b901b9f6c82547b7b87caaf"
+checksum = "e5d732bc30207a6423068df043e3d02e0735b155ad7ce1a6f76fe2baa5b158de"
 dependencies = [
  "libc",
  "log",
  "wasi 0.11.0+wasi-snapshot-preview1",
- "windows-sys 0.36.1",
+ "windows-sys 0.42.0",
 ]

 [[package]]

@@ -1200,9 +1185,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"

 [[package]]
 name = "openssl-sys"
-version = "0.9.76"
+version = "0.9.77"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5230151e44c0f05157effb743e8d517472843121cf9243e8b81393edb5acd9ce"
+checksum = "b03b84c3b2d099b81f0953422b4d4ad58761589d0229b5506356afca05a3670a"
 dependencies = [
  "autocfg",
  "cc",

@@ -1213,9 +1198,9 @@ dependencies = [

 [[package]]
 name = "os_str_bytes"
-version = "6.3.0"
+version = "6.3.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff"
+checksum = "3baf96e39c5359d2eb0dd6ccb42c62b91d9678aa68160d261b9e0ccbf9e9dea9"

 [[package]]
 name = "overload"

@@ -1302,9 +1287,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"

 [[package]]
 name = "pkg-config"
-version = "0.3.25"
+version = "0.3.26"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae"
+checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"

 [[package]]
 name = "ppv-lite86"

@@ -1602,18 +1587,18 @@ dependencies = [

 [[package]]
 name = "serde"
-version = "1.0.145"
+version = "1.0.147"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "728eb6351430bccb993660dfffc5a72f91ccc1295abaa8ce19b27ebe4f75568b"
+checksum = "d193d69bae983fc11a79df82342761dfbf28a99fc8d203dca4c3c1b590948965"
 dependencies = [
  "serde_derive",
 ]

 [[package]]
 name = "serde_derive"
-version = "1.0.145"
+version = "1.0.147"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "81fa1584d3d1bcacd84c277a0dfe21f5b0f6accf4a23d04d4c6d61f1af522b4c"
+checksum = "4f1d362ca8fc9c3e3a7484440752472d68a6caa98f1ab81d99b5dfe517cec852"
 dependencies = [
  "proc-macro2",
  "quote",

@@ -1622,9 +1607,9 @@ dependencies = [

 [[package]]
 name = "serde_json"
-version = "1.0.86"
+version = "1.0.87"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "41feea4228a6f1cd09ec7a3593a682276702cd67b5273544757dae23c096f074"
+checksum = "6ce777b7b150d76b9cf60d28b55f5847135a003f7d7350c6be7a773508ce7d45"
 dependencies = [
  "itoa",
  "ryu",

@@ -1739,9 +1724,9 @@ dependencies = [

 [[package]]
 name = "syn"
-version = "1.0.102"
+version = "1.0.103"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3fcd952facd492f9be3ef0d0b7032a6e442ee9b361d4acc2b1d0c4aaa5f613a1"
+checksum = "a864042229133ada95abf3b54fdc62ef5ccabe9515b64717bcb9a1919e59445d"
 dependencies = [
  "proc-macro2",
  "quote",

@@ -1798,11 +1783,26 @@ dependencies = [
  "winapi",
 ]

+[[package]]
+name = "text-generation-client"
+version = "0.1.0"
+dependencies = [
+ "futures",
+ "prost",
+ "thiserror",
+ "tokio",
+ "tonic",
+ "tonic-build",
+ "tower",
+ "tracing",
+ "tracing-error",
+]
+
 [[package]]
 name = "text-generation-launcher"
 version = "0.1.0"
 dependencies = [
- "clap 4.0.17",
+ "clap 4.0.18",
  "ctrlc",
  "subprocess",
  "tracing",

@@ -1814,12 +1814,12 @@ name = "text-generation-router"
 version = "0.1.0"
 dependencies = [
  "axum",
- "bloom-inference-client",
- "clap 4.0.17",
+ "clap 4.0.18",
  "futures",
  "parking_lot",
  "serde",
  "serde_json",
+ "text-generation-client",
  "thiserror",
  "tokenizers",
  "tokio",
@@ -66,7 +66,7 @@ COPY proto proto
 COPY server server
 RUN cd server && \
     make gen-server && \
-    /opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir
+    /opt/miniconda/envs/text-generation/bin/pip install ".[bnb]" --no-cache-dir

 # Install router
 COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router

Makefile
@@ -22,7 +22,7 @@ run-bloom-560m-quantize:
 	text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2 --quantize

 download-bloom:
-	bloom-inference-server download-weights bigscience/bloom
+	text-generation-server download-weights bigscience/bloom

 run-bloom:
 	text-generation-launcher --model-name bigscience/bloom --num-shard 8

@@ -15,11 +15,13 @@ A Rust and gRPC server for large language models text generation inference.
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - 45ms per token generation for BLOOM with 8xA100 80GB

-## Supported models
+## Officially supported models

 - BLOOM
 - BLOOM-560m

+Other models are supported on a best-effort basis using `AutoModelForCausalLM.from_pretrained(<model>, torch_dtype=torch.float16, device_map="auto")`.
+
 ## Load Tests for BLOOM

 See `k6/load_test.js`
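For illustration only, here is a minimal sketch of the best-effort fallback the README line above describes; the model id and prompt are placeholders rather than anything from this commit, and `device_map="auto"` assumes `accelerate` is installed alongside `transformers`:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "gpt2"  # placeholder: any causal LM checkpoint id, handled on a best-effort basis
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # The same loading call the README documents for non-BLOOM models
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16, device_map="auto"
    )
    inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))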
@@ -1,5 +1,5 @@
 $schema: https://azuremlschemas.azureedge.net/latest/model.schema.json
-name: bloom
+name: bloom-safetensors
 version: 1
-path: ./bloom
+path: ./bloom-safetensors
 type: custom_model

@@ -256,7 +256,7 @@ fn shard_manager(

     // Process args
     let mut shard_argv = vec![
-        "bloom-inference-server".to_string(),
+        "text-generation-server".to_string(),
         "serve".to_string(),
         model_name,
         "--uds-path".to_string(),

@@ -311,7 +311,7 @@ fn shard_manager(
         Err(err) => {
             if let PopenError::IoError(ref err) = err {
                 if err.kind() == io::ErrorKind::NotFound {
-                    tracing::error!("bloom-inference-server not found in PATH");
+                    tracing::error!("text-generation-server not found in PATH");
                     tracing::error!("Please install it with `make install-server`")
                 }
             }

@@ -14,7 +14,7 @@ path = "src/main.rs"

 [dependencies]
 axum = { version = "0.5.16", features = ["json", "serde_json"] }
-bloom-inference-client = { path = "client" }
+text-generation-client = { path = "client" }
 clap = { version = "4.0.15", features = ["derive", "env"] }
 futures = "0.3.24"
 parking_lot = "0.12.1"

@@ -1,5 +1,5 @@
 [package]
-name = "bloom-inference-client"
+name = "text-generation-client"
 version = "0.1.0"
 edition = "2021"

@@ -3,9 +3,9 @@ use crate::{Db, Entry};
 use crate::{ErrorResponse, GenerateRequest};
 use axum::http::StatusCode;
 use axum::Json;
-use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use std::future::Future;
 use std::sync::Arc;
+use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use thiserror::Error;
 use tokio::sync::{oneshot, Notify};
 use tokio::time::Instant;

@@ -1,10 +1,10 @@
 use crate::InferResponse;
 /// This code is massively inspired by Tokio mini-redis
 use crate::{GenerateParameters, GenerateRequest};
-use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
 use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
+use text_generation_client::{Batch, ClientError, LogitsWarperParameters, Request};
 use tokio::sync::oneshot::Sender;
 use tokio::time::Instant;

@@ -1,7 +1,7 @@
-/// Text Generation Inference webserver entrypoint
-use bloom_inference_client::ShardedClient;
 use clap::Parser;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+/// Text Generation Inference webserver entrypoint
+use text_generation_client::ShardedClient;
 use text_generation_router::server;
 use tokenizers::Tokenizer;

@@ -19,7 +19,7 @@ struct Args {
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
-    #[clap(default_value = "/tmp/bloom-inference-0", long, env)]
+    #[clap(default_value = "/tmp/text-generation-0", long, env)]
     master_shard_uds_path: String,
     #[clap(default_value = "bigscience/bloom", long, env)]
     tokenizer_name: String,

@@ -6,9 +6,9 @@ use axum::http::{HeaderMap, StatusCode};
 use axum::response::IntoResponse;
 use axum::routing::{get, post};
 use axum::{Json, Router};
-use bloom_inference_client::ShardedClient;
 use std::net::SocketAddr;
 use std::sync::Arc;
+use text_generation_client::ShardedClient;
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::sync::Semaphore;

@@ -1,7 +1,7 @@
 # Byte-compiled / optimized / DLL files
 __pycache__/
-bloom_inference/__pycache__/
-bloom_inference/pb/__pycache__/
+text_generation/__pycache__/
+text_generation/pb/__pycache__/
 *.py[cod]
 *$py.class

@@ -1,10 +1,10 @@
 gen-server:
 	# Compile protos
 	pip install grpcio-tools==1.49.1 --no-cache-dir
-	mkdir bloom_inference/pb || true
-	python -m grpc_tools.protoc -I../proto --python_out=bloom_inference/pb --grpc_python_out=bloom_inference/pb ../proto/generate.proto
-	find bloom_inference/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
-	touch bloom_inference/pb/__init__.py
+	mkdir text_generation/pb || true
+	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
+	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+	touch text_generation/pb/__init__.py

 install-transformers:
 	# Install specific version of transformers

@@ -36,4 +36,4 @@ install: gen-server install-torch install-transformers install-safetensors
 	pip install -e . --no-cache-dir

 run-dev:
-	python -m torch.distributed.run --nproc_per_node=2 bloom_inference/cli.py serve bigscience/bloom-560m --sharded
+	python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
@ -1,582 +0,0 @@
|
||||||
import torch
|
|
||||||
import torch.distributed
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import List, Tuple, Optional, Dict
|
|
||||||
|
|
||||||
from accelerate import init_empty_weights
|
|
||||||
from safetensors import safe_open
|
|
||||||
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
|
||||||
from transformers.models.bloom.parallel_layers import (
|
|
||||||
TensorParallelColumnLinear,
|
|
||||||
TensorParallelEmbedding,
|
|
||||||
TensorParallelRowLinear,
|
|
||||||
)
|
|
||||||
|
|
||||||
from bloom_inference.pb import generate_pb2
|
|
||||||
from bloom_inference.utils import (
|
|
||||||
StoppingCriteria,
|
|
||||||
NextTokenChooser,
|
|
||||||
initialize_torch_distributed,
|
|
||||||
weight_files,
|
|
||||||
download_weights,
|
|
||||||
)
|
|
||||||
|
|
||||||
HAS_BITS_AND_BYTES = True
|
|
||||||
try:
|
|
||||||
import bitsandbytes as bnb
|
|
||||||
from bitsandbytes.nn import Int8Params
|
|
||||||
except Exception as e:
|
|
||||||
HAS_BITS_AND_BYTES = False
|
|
||||||
|
|
||||||
torch.manual_seed(0)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Batch:
|
|
||||||
batch_id: int
|
|
||||||
requests: List[generate_pb2.Request]
|
|
||||||
all_input_lengths: List[int]
|
|
||||||
input_ids: Dict[str, torch.Tensor]
|
|
||||||
all_input_ids: List[torch.Tensor]
|
|
||||||
next_token_choosers: List[NextTokenChooser]
|
|
||||||
stopping_criterias: List[StoppingCriteria]
|
|
||||||
size: int
|
|
||||||
max_sequence_length: int
|
|
||||||
|
|
||||||
def to_pb(self):
|
|
||||||
return generate_pb2.Batch(
|
|
||||||
id=self.batch_id,
|
|
||||||
requests=self.requests,
|
|
||||||
size=self.size,
|
|
||||||
max_sequence_length=self.max_sequence_length,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pb(
|
|
||||||
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
|
|
||||||
) -> "Batch":
|
|
||||||
inputs = []
|
|
||||||
next_token_choosers = []
|
|
||||||
stopping_criterias = []
|
|
||||||
all_input_lengths = []
|
|
||||||
|
|
||||||
# Parse batch
|
|
||||||
for r in pb.requests:
|
|
||||||
inputs.append(r.inputs)
|
|
||||||
all_input_lengths.append(r.input_length)
|
|
||||||
next_token_choosers.append(
|
|
||||||
NextTokenChooser(
|
|
||||||
temperature=r.parameters.temperature,
|
|
||||||
top_k=r.parameters.top_k,
|
|
||||||
top_p=r.parameters.top_p,
|
|
||||||
do_sample=r.parameters.do_sample,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))
|
|
||||||
|
|
||||||
input_ids = tokenizer(
|
|
||||||
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
|
|
||||||
).to(device)
|
|
||||||
all_input_ids = input_ids["input_ids"].unsqueeze(-1)
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
batch_id=pb.id,
|
|
||||||
requests=pb.requests,
|
|
||||||
all_input_lengths=all_input_lengths,
|
|
||||||
input_ids=input_ids,
|
|
||||||
all_input_ids=all_input_ids,
|
|
||||||
next_token_choosers=next_token_choosers,
|
|
||||||
stopping_criterias=stopping_criterias,
|
|
||||||
size=pb.size,
|
|
||||||
max_sequence_length=pb.max_sequence_length,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def concatenate(cls, batches: List["Batch"]) -> "Batch":
|
|
||||||
# Used for padding
|
|
||||||
total_batch_size = sum(batch.size for batch in batches)
|
|
||||||
max_sequence_length = max(batch.max_sequence_length for batch in batches)
|
|
||||||
|
|
||||||
# Batch attributes
|
|
||||||
input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
|
|
||||||
requests = []
|
|
||||||
all_input_lengths = []
|
|
||||||
all_input_ids = []
|
|
||||||
next_token_choosers = []
|
|
||||||
stopping_criterias = []
|
|
||||||
|
|
||||||
# Used for slicing correctly inside the tensors
|
|
||||||
# Equivalent to a cumsum on batch sizes
|
|
||||||
start_index = 0
|
|
||||||
for i, batch in enumerate(batches):
|
|
||||||
requests.extend(batch.requests)
|
|
||||||
all_input_lengths.extend(batch.all_input_lengths)
|
|
||||||
all_input_ids.extend(batch.all_input_ids)
|
|
||||||
next_token_choosers.extend(batch.next_token_choosers)
|
|
||||||
stopping_criterias.extend(batch.stopping_criterias)
|
|
||||||
|
|
||||||
# Slicing end index for this batch
|
|
||||||
end_index = start_index + batch.size
|
|
||||||
|
|
||||||
# We only concatenate batches that did at least one step
|
|
||||||
if batch.input_ids["input_ids"].shape[1] > 1:
|
|
||||||
raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
|
|
||||||
|
|
||||||
# Initialize tensors
|
|
||||||
if i == 0:
|
|
||||||
input_ids["input_ids"] = torch.empty(
|
|
||||||
(total_batch_size, 1),
|
|
||||||
dtype=batch.input_ids["input_ids"].dtype,
|
|
||||||
device=batch.input_ids["input_ids"].device,
|
|
||||||
)
|
|
||||||
input_ids["attention_mask"] = torch.zeros(
|
|
||||||
(total_batch_size, max_sequence_length),
|
|
||||||
dtype=batch.input_ids["attention_mask"].dtype,
|
|
||||||
device=batch.input_ids["attention_mask"].device,
|
|
||||||
)
|
|
||||||
|
|
||||||
# input_ids["input_ids"] is always of shape [batch_size, 1]
|
|
||||||
# We do not need to pad it
|
|
||||||
input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
|
|
||||||
|
|
||||||
# We need to slice the attention mask to remove padding from previous steps
|
|
||||||
input_ids["attention_mask"][
|
|
||||||
start_index:end_index, -batch.max_sequence_length :
|
|
||||||
] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]
|
|
||||||
|
|
||||||
for j, past in enumerate(batch.input_ids["past_key_values"]):
|
|
||||||
past_keys = past[0]
|
|
||||||
past_values = past[1]
|
|
||||||
|
|
||||||
_, head_dim, padded_sequence_length = past_keys.shape
|
|
||||||
|
|
||||||
# Reshape the tensors to make slicing easier
|
|
||||||
past_keys = past_keys.view(
|
|
||||||
batch.size, -1, head_dim, padded_sequence_length
|
|
||||||
)
|
|
||||||
past_values = past_values.view(
|
|
||||||
batch.size, -1, padded_sequence_length, head_dim
|
|
||||||
)
|
|
||||||
num_heads = past_keys.shape[1]
|
|
||||||
|
|
||||||
# Initialize tensors
|
|
||||||
# This will run only once per layer
|
|
||||||
if j == len(input_ids["past_key_values"]):
|
|
||||||
padded_past_keys = torch.zeros(
|
|
||||||
(
|
|
||||||
total_batch_size,
|
|
||||||
num_heads,
|
|
||||||
head_dim,
|
|
||||||
max_sequence_length - 1,
|
|
||||||
),
|
|
||||||
dtype=past_keys.dtype,
|
|
||||||
device=past_keys.device,
|
|
||||||
)
|
|
||||||
padded_past_values = torch.zeros(
|
|
||||||
(
|
|
||||||
total_batch_size,
|
|
||||||
num_heads,
|
|
||||||
max_sequence_length - 1,
|
|
||||||
head_dim,
|
|
||||||
),
|
|
||||||
dtype=past_values.dtype,
|
|
||||||
device=past_values.device,
|
|
||||||
)
|
|
||||||
input_ids["past_key_values"].append(
|
|
||||||
[padded_past_keys, padded_past_values]
|
|
||||||
)
|
|
||||||
|
|
||||||
# We slice the past keys and values to remove the padding from previous batches
|
|
||||||
input_ids["past_key_values"][j][0][
|
|
||||||
start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
|
|
||||||
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
|
|
||||||
|
|
||||||
input_ids["past_key_values"][j][1][
|
|
||||||
start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
|
|
||||||
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
|
|
||||||
|
|
||||||
# If we are on the last batch, we need to reshape the tensors
|
|
||||||
if (i + 1) == len(batches):
|
|
||||||
input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
|
|
||||||
j
|
|
||||||
][0].view(total_batch_size * num_heads, head_dim, -1)
|
|
||||||
input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
|
|
||||||
j
|
|
||||||
][1].view(total_batch_size * num_heads, -1, head_dim)
|
|
||||||
|
|
||||||
start_index += batch.size
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
batch_id=batches[0].batch_id,
|
|
||||||
requests=requests,
|
|
||||||
all_input_lengths=all_input_lengths,
|
|
||||||
input_ids=input_ids,
|
|
||||||
all_input_ids=all_input_ids,
|
|
||||||
next_token_choosers=next_token_choosers,
|
|
||||||
stopping_criterias=stopping_criterias,
|
|
||||||
size=total_batch_size,
|
|
||||||
max_sequence_length=max_sequence_length,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GeneratedText:
|
|
||||||
request: generate_pb2.Request
|
|
||||||
output: str
|
|
||||||
|
|
||||||
def to_pb(self) -> generate_pb2.GeneratedText:
|
|
||||||
return generate_pb2.GeneratedText(request=self.request, output=self.output)
|
|
||||||
|
|
||||||
|
|
||||||
class BLOOM:
|
|
||||||
def __init__(self, model_name: str):
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
self.device = torch.device("cuda")
|
|
||||||
dtype = torch.bfloat16
|
|
||||||
else:
|
|
||||||
self.device = torch.device("cpu")
|
|
||||||
dtype = torch.float32
|
|
||||||
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
|
|
||||||
self.model = (
|
|
||||||
AutoModelForCausalLM.from_pretrained(model_name)
|
|
||||||
.eval()
|
|
||||||
.to(self.device)
|
|
||||||
.to(dtype)
|
|
||||||
)
|
|
||||||
self.num_heads = self.model.base_model.num_heads
|
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
|
|
||||||
# Model Forward
|
|
||||||
return self.model.forward(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
use_cache=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def generate_token(
|
|
||||||
self, batch: Batch
|
|
||||||
) -> Tuple[List[GeneratedText], Optional[Batch]]:
|
|
||||||
with torch.inference_mode():
|
|
||||||
outputs = self.forward(**batch.input_ids)
|
|
||||||
|
|
||||||
# List of indices to cache
|
|
||||||
next_batch_keep_indices = []
|
|
||||||
next_batch_past_keep_indices = []
|
|
||||||
|
|
||||||
# New input_ids for next forward
|
|
||||||
next_batch_input_ids = []
|
|
||||||
next_batch_all_input_ids = []
|
|
||||||
next_all_input_lengths = []
|
|
||||||
|
|
||||||
next_batch_size = 0
|
|
||||||
next_batch_max_sequence_length = 0
|
|
||||||
|
|
||||||
# Finished requests
|
|
||||||
generated_texts: List[GeneratedText] = []
|
|
||||||
|
|
||||||
# Zipped iterator
|
|
||||||
iterator = zip(
|
|
||||||
batch.requests,
|
|
||||||
batch.all_input_lengths,
|
|
||||||
outputs.logits,
|
|
||||||
batch.next_token_choosers,
|
|
||||||
batch.stopping_criterias,
|
|
||||||
batch.all_input_ids,
|
|
||||||
)
|
|
||||||
|
|
||||||
# For each member of the batch
|
|
||||||
for i, (
|
|
||||||
request,
|
|
||||||
input_length,
|
|
||||||
logits,
|
|
||||||
next_token_chooser,
|
|
||||||
stopping_criteria,
|
|
||||||
all_tokens,
|
|
||||||
) in enumerate(iterator):
|
|
||||||
# Select next token
|
|
||||||
next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
|
|
||||||
|
|
||||||
# Append next token to all tokens
|
|
||||||
all_tokens = torch.cat([all_tokens, next_token])
|
|
||||||
|
|
||||||
# Evaluate stopping criteria
|
|
||||||
if stopping_criteria(all_tokens):
|
|
||||||
# Decode all tokens
|
|
||||||
output = self.tokenizer.decode(
|
|
||||||
all_tokens.squeeze(-1), skip_special_tokens=True
|
|
||||||
)
|
|
||||||
# Add to the list of finished generations with the original request
|
|
||||||
generated_texts.append(GeneratedText(request, output))
|
|
||||||
# add to the next batch
|
|
||||||
else:
|
|
||||||
next_batch_keep_indices.append(i)
|
|
||||||
# past_key_values is of shape [batch_size * num_heads, ...]
|
|
||||||
# so we need to take into account the `num_heads` stride here
|
|
||||||
next_batch_past_keep_indices.extend(
|
|
||||||
[j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
|
|
||||||
)
|
|
||||||
next_batch_input_ids.append(next_token)
|
|
||||||
next_batch_all_input_ids.append(all_tokens)
|
|
||||||
next_batch_size += 1
|
|
||||||
new_input_length = input_length + 1
|
|
||||||
next_all_input_lengths.append(new_input_length)
|
|
||||||
next_batch_max_sequence_length = max(
|
|
||||||
next_batch_max_sequence_length, new_input_length
|
|
||||||
)
|
|
||||||
|
|
||||||
# We finished all generations in the batch; there is no next batch
|
|
||||||
if not next_batch_keep_indices:
|
|
||||||
return generated_texts, None
|
|
||||||
|
|
||||||
# If we finished at least one generation
|
|
||||||
next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
|
|
||||||
if generated_texts:
|
|
||||||
# Apply indices to attention mask, past key values and other items that need to be cached
|
|
||||||
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
|
|
||||||
next_batch_keep_indices
|
|
||||||
]
|
|
||||||
next_batch_input_ids["past_key_values"] = [
|
|
||||||
(
|
|
||||||
keys[next_batch_past_keep_indices],
|
|
||||||
values[next_batch_past_keep_indices],
|
|
||||||
)
|
|
||||||
for keys, values in outputs["past_key_values"]
|
|
||||||
]
|
|
||||||
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
|
|
||||||
next_batch_next_token_choosers = [
|
|
||||||
batch.next_token_choosers[i] for i in next_batch_keep_indices
|
|
||||||
]
|
|
||||||
next_batch_stopping_criterias = [
|
|
||||||
batch.stopping_criterias[i] for i in next_batch_keep_indices
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
|
|
||||||
next_batch_input_ids["past_key_values"] = outputs["past_key_values"]
|
|
||||||
next_batch_requests = batch.requests
|
|
||||||
next_batch_next_token_choosers = batch.next_token_choosers
|
|
||||||
next_batch_stopping_criterias = batch.stopping_criterias
|
|
||||||
|
|
||||||
# Update attention_mask with padding as we added a new token to input_ids
|
|
||||||
next_batch_input_ids["attention_mask"] = torch.cat(
|
|
||||||
[
|
|
||||||
next_batch_input_ids["attention_mask"],
|
|
||||||
torch.ones((next_batch_size, 1)).to(self.device),
|
|
||||||
],
|
|
||||||
dim=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
next_batch = Batch(
|
|
||||||
batch_id=batch.batch_id,
|
|
||||||
requests=next_batch_requests,
|
|
||||||
all_input_lengths=next_all_input_lengths,
|
|
||||||
input_ids=next_batch_input_ids,
|
|
||||||
all_input_ids=next_batch_all_input_ids,
|
|
||||||
next_token_choosers=next_batch_next_token_choosers,
|
|
||||||
stopping_criterias=next_batch_stopping_criterias,
|
|
||||||
size=next_batch_size,
|
|
||||||
max_sequence_length=next_batch_max_sequence_length,
|
|
||||||
)
|
|
||||||
return generated_texts, next_batch
|
|
||||||
|
|
||||||
|
|
||||||
class BLOOMSharded(BLOOM):
|
|
||||||
def __init__(self, model_name: str, quantize: bool = False):
|
|
||||||
super(BLOOM, self).__init__()
|
|
||||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
|
||||||
self.master = self.rank == 0
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
self.device = torch.device(f"cuda:{self.rank}")
|
|
||||||
dtype = torch.float16
|
|
||||||
else:
|
|
||||||
self.device = torch.device("cpu")
|
|
||||||
dtype = torch.float32
|
|
||||||
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
|
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
|
||||||
model_name, slow_but_exact=False, tp_parallel=True
|
|
||||||
)
|
|
||||||
config.pad_token_id = 3
|
|
||||||
self.num_heads = config.n_head // self.process_group.size()
|
|
||||||
|
|
||||||
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
|
|
||||||
# in PyTorch 1.12 and later.
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
|
||||||
|
|
||||||
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
|
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
|
||||||
|
|
||||||
# Only download weights for small models
|
|
||||||
if self.master and model_name == "bigscience/bloom-560m":
|
|
||||||
download_weights(model_name)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
filenames = weight_files(model_name)
|
|
||||||
|
|
||||||
with init_empty_weights():
|
|
||||||
model = AutoModelForCausalLM.from_config(config)
|
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
self.load_weights(
|
|
||||||
model,
|
|
||||||
filenames,
|
|
||||||
quantize=quantize,
|
|
||||||
device=self.device,
|
|
||||||
rank=self.rank,
|
|
||||||
world_size=self.world_size,
|
|
||||||
)
|
|
||||||
self.model = model.eval().to(dtype)
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load_weights(
|
|
||||||
model,
|
|
||||||
filenames: List[str],
|
|
||||||
quantize: bool,
|
|
||||||
device: torch.device,
|
|
||||||
rank: int,
|
|
||||||
world_size: int,
|
|
||||||
):
|
|
||||||
parameters = dict(model.named_parameters())
|
|
||||||
for file in filenames:
|
|
||||||
with safe_open(
|
|
||||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
|
||||||
) as f:
|
|
||||||
for name in f.keys():
|
|
||||||
full_name = f"transformer.{name}"
|
|
||||||
|
|
||||||
module_name, param_name = full_name.rsplit(".", 1)
|
|
||||||
module = model.get_submodule(module_name)
|
|
||||||
current_tensor = parameters[full_name]
|
|
||||||
|
|
||||||
slice_ = f.get_slice(name)
|
|
||||||
|
|
||||||
if isinstance(module, TensorParallelColumnLinear):
|
|
||||||
if param_name == "weight":
|
|
||||||
size = slice_.get_shape()[0]
|
|
||||||
block_size = size // world_size
|
|
||||||
start = rank * block_size
|
|
||||||
stop = (rank + 1) * block_size
|
|
||||||
tensor = slice_[start:stop]
|
|
||||||
tensor = tensor.transpose(1, 0)
|
|
||||||
else:
|
|
||||||
size = slice_.get_shape()[0]
|
|
||||||
block_size = size // world_size
|
|
||||||
start = rank * block_size
|
|
||||||
stop = (rank + 1) * block_size
|
|
||||||
tensor = slice_[start:stop]
|
|
||||||
elif isinstance(module, TensorParallelRowLinear):
|
|
||||||
if param_name == "weight":
|
|
||||||
size = slice_.get_shape()[1]
|
|
||||||
block_size = size // world_size
|
|
||||||
start = rank * block_size
|
|
||||||
stop = (rank + 1) * block_size
|
|
||||||
tensor = slice_[:, start:stop]
|
|
||||||
tensor = tensor.transpose(1, 0)
|
|
||||||
else:
|
|
||||||
tensor = slice_[:]
|
|
||||||
# XXX: Hack for Rowlinear to add the bias only once.
|
|
||||||
if rank != 0:
|
|
||||||
tensor = torch.zeros_like(tensor)
|
|
||||||
elif isinstance(module, TensorParallelEmbedding):
|
|
||||||
size = slice_.get_shape()[0]
|
|
||||||
block_size = size // world_size
|
|
||||||
start = rank * block_size
|
|
||||||
stop = (rank + 1) * block_size
|
|
||||||
tensor = slice_[start:stop]
|
|
||||||
else:
|
|
||||||
tensor = slice_[:]
|
|
||||||
|
|
||||||
if current_tensor.shape != tensor.shape:
|
|
||||||
raise ValueError(
|
|
||||||
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
|
|
||||||
)
|
|
||||||
|
|
||||||
tensor = tensor.contiguous()
|
|
||||||
|
|
||||||
if quantize:
|
|
||||||
if not HAS_BITS_AND_BYTES:
|
|
||||||
raise ImportError(
|
|
||||||
"bitsandbytes is not available on your machine"
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
type(module)
|
|
||||||
in [TensorParallelRowLinear, TensorParallelColumnLinear]
|
|
||||||
and param_name == "weight"
|
|
||||||
):
|
|
||||||
tensor = Int8Params(
|
|
||||||
tensor.transpose(1, 0),
|
|
||||||
has_fp16_weights=False,
|
|
||||||
requires_grad=False,
|
|
||||||
).to(device)
|
|
||||||
state = bnb.MatmulLtState()
|
|
||||||
state.threshold = 6.0
|
|
||||||
state.has_fp16_weights = False
|
|
||||||
state.memory_efficient_backward = False
|
|
||||||
state.use_pool = True
|
|
||||||
state.CB = tensor.CB
|
|
||||||
state.SCB = tensor.SCB
|
|
||||||
tensor.CB = None
|
|
||||||
tensor.SCB = None
|
|
||||||
|
|
||||||
def replace_linear(state, in_features, out_features):
|
|
||||||
def linear(input, weight, bias):
|
|
||||||
size_out = input.size()[:-1] + (out_features,)
|
|
||||||
input = input.view(-1, in_features)
|
|
||||||
out = torch.empty(
|
|
||||||
size_out, device=input.device, dtype=input.dtype
|
|
||||||
)
|
|
||||||
out = bnb.matmul(
|
|
||||||
input,
|
|
||||||
weight,
|
|
||||||
out=out.view(-1, out_features),
|
|
||||||
state=state,
|
|
||||||
threshold=state.threshold,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
if state.CB is not None:
|
|
||||||
# we converted 8-bit row major to turing/ampere format
|
|
||||||
# in the first inference pass
|
|
||||||
# we no longer need the row-major weight
|
|
||||||
del state.CB
|
|
||||||
weight.data = state.CxB
|
|
||||||
|
|
||||||
return out.view(size_out)
|
|
||||||
|
|
||||||
return linear
|
|
||||||
|
|
||||||
module.linear = replace_linear(
|
|
||||||
state, module.in_features, module.out_features
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
tensor = tensor.to(device)
|
|
||||||
|
|
||||||
module._parameters[param_name] = tensor
|
|
||||||
if name == "word_embeddings.weight":
|
|
||||||
model.lm_head._parameters["weight"] = tensor
|
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
|
|
||||||
outputs = self.model.forward(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
use_cache=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Logits are sharded, so we need to gather them
|
|
||||||
logits_shard = outputs.logits[:, -1, :].contiguous()
|
|
||||||
|
|
||||||
batch_size, vocab_shard_size = logits_shard.shape
|
|
||||||
vocab_size = self.world_size * vocab_shard_size
|
|
||||||
logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
|
|
||||||
torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
|
|
||||||
logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)
|
|
||||||
|
|
||||||
outputs.logits = logits
|
|
||||||
return outputs
|
|
|
@@ -1,11 +1,11 @@
 [tool.poetry]
-name = "bloom-inference"
+name = "text-generation"
 version = "0.1.0"
 description = "BLOOM Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]

 [tool.poetry.scripts]
-bloom-inference-server = 'bloom_inference.cli:app'
+text-generation-server = 'text_generation.cli:app'

 [tool.poetry.dependencies]
 python = "^3.9"

@@ -17,6 +17,9 @@ accelerate = "^0.12.0"
 joblib = "^1.2.0"
 bitsandbytes = "^0.35.1"

+[tool.poetry.extras]
+bnb = ["bitsandbytes"]
+
 [tool.poetry.group.dev.dependencies]
 grpcio-tools = "^1.49.1"

@@ -1,6 +1,7 @@
-from bloom_inference.model import Batch
 from typing import Dict, Optional
+
+from text_generation.models.types import Batch


 class Cache:
     def __init__(self):

@@ -3,7 +3,7 @@ import typer

 from pathlib import Path

-from bloom_inference import server, utils
+from text_generation import server, utils

 app = typer.Typer()

@@ -13,7 +13,7 @@ def serve(
     model_name: str,
     sharded: bool = False,
     quantize: bool = False,
-    uds_path: Path = "/tmp/bloom-inference",
+    uds_path: Path = "/tmp/text-generation",
 ):
     if sharded:
         assert (

@@ -35,8 +35,9 @@ def serve(
 @app.command()
 def download_weights(
     model_name: str,
+    extension: str = ".safetensors",
 ):
-    utils.download_weights(model_name)
+    utils.download_weights(model_name, extension)


 if __name__ == "__main__":
@@ -0,0 +1,22 @@
+from text_generation.models.model import Model
+from text_generation.models.bloom import BLOOMSharded
+
+__all__ = ["Model", "BLOOMSharded"]
+
+
+def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
+
+    if model_name.startswith("bigscience/bloom"):
+        if sharded:
+            return BLOOMSharded(model_name, quantize)
+        else:
+            if quantize:
+                raise ValueError("quantization is not supported for non-sharded BLOOM")
+            return Model(model_name)
+    else:
+        if sharded:
+            raise ValueError("sharded is only supported for BLOOM")
+        if quantize:
+            raise ValueError("Quantization is only supported for BLOOM models")
+
+        return Model(model_name)
@@ -0,0 +1,231 @@
+import torch
+import torch.distributed
+
+from typing import List, Optional
+
+from accelerate import init_empty_weights
+from safetensors import safe_open
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers.models.bloom.parallel_layers import (
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+    TensorParallelRowLinear,
+)
+
+from text_generation.models import Model
+from text_generation.utils import (
+    initialize_torch_distributed,
+    weight_files,
+    download_weights,
+)
+
+HAS_BITS_AND_BYTES = True
+try:
+    import bitsandbytes as bnb
+    from bitsandbytes.nn import Int8Params
+except Exception as e:
+    HAS_BITS_AND_BYTES = False
+
+torch.manual_seed(0)
+
+
+class BLOOMSharded(Model):
+    def __init__(self, model_name: str, quantize: bool = False):
+        super(Model, self).__init__()
+        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
+        self.master = self.rank == 0
+        if torch.cuda.is_available():
+            self.device = torch.device(f"cuda:{self.rank}")
+            dtype = torch.float16
+        else:
+            self.device = torch.device("cpu")
+            dtype = torch.float32
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+
+        config = AutoConfig.from_pretrained(
+            model_name, slow_but_exact=False, tp_parallel=True
+        )
+        config.pad_token_id = 3
+        self.num_heads = config.n_head // self.process_group.size()
+
+        # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
+        # in PyTorch 1.12 and later.
+        torch.backends.cuda.matmul.allow_tf32 = True
+
+        # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
+        torch.backends.cudnn.allow_tf32 = True
+
+        # Only download weights for small models
+        if self.master and model_name == "bigscience/bloom-560m":
+            download_weights(model_name, extension=".safetensors")
+
+        torch.distributed.barrier(group=self.process_group)
+        filenames = weight_files(model_name, extension=".safetensors")
+
+        with init_empty_weights():
+            model = AutoModelForCausalLM.from_config(config)
+
+        torch.distributed.barrier(group=self.process_group)
+        self.load_weights(
+            model,
+            filenames,
+            quantize=quantize,
+            device=self.device,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+        self.model = model.eval().to(dtype)
+        torch.distributed.barrier(group=self.process_group)
+
+    @staticmethod
+    def load_weights(
+        model,
+        filenames: List[str],
+        quantize: bool,
+        device: torch.device,
+        rank: int,
+        world_size: int,
+    ):
+        parameters = dict(model.named_parameters())
+        for file in filenames:
+            with safe_open(
+                file, framework="pt", device=str(device) if not quantize else "cpu"
+            ) as f:
+                for name in f.keys():
+                    full_name = f"transformer.{name}"
+
+                    module_name, param_name = full_name.rsplit(".", 1)
+                    module = model.get_submodule(module_name)
+                    current_tensor = parameters[full_name]
+
+                    slice_ = f.get_slice(name)
+
+                    if isinstance(module, TensorParallelColumnLinear):
+                        if param_name == "weight":
+                            size = slice_.get_shape()[0]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = slice_[start:stop]
+                            tensor = tensor.transpose(1, 0)
+                        else:
+                            size = slice_.get_shape()[0]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = slice_[start:stop]
+                    elif isinstance(module, TensorParallelRowLinear):
+                        if param_name == "weight":
+                            size = slice_.get_shape()[1]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = slice_[:, start:stop]
+                            tensor = tensor.transpose(1, 0)
+                        else:
+                            tensor = slice_[:]
+                            # XXX: Hack for Rowlinear to add the bias only once.
+                            if rank != 0:
+                                tensor = torch.zeros_like(tensor)
+                    elif isinstance(module, TensorParallelEmbedding):
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    else:
+                        tensor = slice_[:]
+
+                    if current_tensor.shape != tensor.shape:
+                        raise ValueError(
+                            f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
+                        )
+
+                    tensor = tensor.contiguous()
+
+                    if quantize:
+                        if not HAS_BITS_AND_BYTES:
+                            raise ImportError(
+                                "bitsandbytes is not available on your machine either because it is not installed "
+                                "or you don't have a GPU.\n"
+                                "You can install it with `pip install bitsandbytes`."
+                            )
+
+                        if (
+                            type(module)
+                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
+                            and param_name == "weight"
+                        ):
+                            tensor = Int8Params(
+                                tensor.transpose(1, 0),
+                                has_fp16_weights=False,
+                                requires_grad=False,
+                            ).to(device)
+                            state = bnb.MatmulLtState()
+                            state.threshold = 6.0
+                            state.has_fp16_weights = False
+                            state.memory_efficient_backward = False
+                            state.use_pool = True
+                            state.CB = tensor.CB
+                            state.SCB = tensor.SCB
+                            tensor.CB = None
+                            tensor.SCB = None
+
+                            def replace_linear(state, in_features, out_features):
+                                def linear(input, weight, bias):
+                                    size_out = input.size()[:-1] + (out_features,)
+                                    input = input.view(-1, in_features)
+                                    out = torch.empty(
+                                        size_out, device=input.device, dtype=input.dtype
+                                    )
+                                    out = bnb.matmul(
+                                        input,
+                                        weight,
+                                        out=out.view(-1, out_features),
+                                        state=state,
+                                        threshold=state.threshold,
+                                        bias=bias,
+                                    )
+
+                                    if state.CB is not None:
+                                        # we converted 8-bit row major to turing/ampere format
+                                        # in the first inference pass
+                                        # we no longer need the row-major weight
+                                        del state.CB
+                                        weight.data = state.CxB
+
+                                    return out.view(size_out)
+
+                                return linear
+
+                            module.linear = replace_linear(
+                                state, module.in_features, module.out_features
+                            )
+
+                        else:
+                            tensor = tensor.to(device)
+
+                    module._parameters[param_name] = tensor
+                    if name == "word_embeddings.weight":
+                        model.lm_head._parameters["weight"] = tensor
+
+    def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
+        outputs = self.model.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=True,
+        )
+
+        # Logits are sharded, so we need to gather them
+        logits_shard = outputs.logits[:, -1, :].contiguous()
+
+        batch_size, vocab_shard_size = logits_shard.shape
+        vocab_size = self.world_size * vocab_shard_size
+        logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
+        torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
+        logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)
+
+        outputs.logits = logits
+        return outputs
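The sharding arithmetic in `load_weights` is easier to see on a toy tensor. A standalone sketch (made-up shapes and a fake rank/world size, no distributed setup required):

```python
import torch

world_size, rank = 4, 1
weight = torch.arange(32, dtype=torch.float32).reshape(8, 4)  # (out_features, in_features)

# Column-parallel linear: each rank keeps a contiguous block of output rows.
block = weight.shape[0] // world_size
column_shard = weight[rank * block : (rank + 1) * block]

# Row-parallel linear: each rank keeps a contiguous block of input columns.
block = weight.shape[1] // world_size
row_shard = weight[:, rank * block : (rank + 1) * block]

print(column_shard.shape, row_shard.shape)  # torch.Size([2, 4]) torch.Size([8, 1])
```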
@@ -0,0 +1,166 @@
+import torch
+import torch.distributed
+
+from typing import List, Tuple, Optional
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from text_generation.models.types import Batch, GeneratedText
+
+
+class Model:
+    def __init__(self, model_name: str):
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+            dtype = torch.float16
+        else:
+            self.device = torch.device("cpu")
+            dtype = torch.float32
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name, torch_dtype=dtype, device_map="auto"
+        ).eval()
+
+        self.num_heads = self.model.config.num_attention_heads
+
+    def forward(
+        self, input_ids, attention_mask, past_key_values: Optional = None
+    ) -> CausalLMOutputWithPast:
+        # Model Forward
+        return self.model.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=True,
+        )
+
+    def generate_token(
+        self, batch: Batch
+    ) -> Tuple[List[GeneratedText], Optional[Batch]]:
+        # For some reason, inference_mode does not work well with GLOO which we use on CPU
+        context_manager = (
+            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
+        )
+        with context_manager():
+            outputs = self.forward(**batch.input_ids)
+
+        # List of indices to cache
+        next_batch_keep_indices = []
+        next_batch_past_keep_indices = []
+
+        # New input_ids for next forward
+        next_batch_input_ids = []
+        next_batch_all_input_ids = []
+        next_all_input_lengths = []
+
+        next_batch_size = 0
+        next_batch_max_sequence_length = 0
+
+        # Finished requests
+        generated_texts: List[GeneratedText] = []
+
+        # Zipped iterator
+        iterator = zip(
+            batch.requests,
+            batch.all_input_lengths,
+            outputs.logits,
+            batch.next_token_choosers,
+            batch.stopping_criterias,
+            batch.all_input_ids,
+        )
+
+        # For each member of the batch
+        for i, (
+            request,
+            input_length,
+            logits,
+            next_token_chooser,
+            stopping_criteria,
+            all_tokens,
+        ) in enumerate(iterator):
+            # Select next token
+            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
+
+            # Append next token to all tokens
+            all_tokens = torch.cat([all_tokens, next_token])
+
+            # Evaluate stopping criteria
+            if stopping_criteria(all_tokens):
+                # Decode all tokens
+                output = self.tokenizer.decode(
+                    all_tokens.squeeze(-1), skip_special_tokens=True
+                )
+                # Add to the list of finished generations with the original request
+                generated_texts.append(GeneratedText(request, output))
+            # add to the next batch
+            else:
+                next_batch_keep_indices.append(i)
+                # past_key_values is of shape [batch_size * num_heads, ...]
+                # so we need to take into account the `num_heads` stride here
+                next_batch_past_keep_indices.extend(
+                    [j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
+                )
+                next_batch_input_ids.append(next_token)
+                next_batch_all_input_ids.append(all_tokens)
+                next_batch_size += 1
+                new_input_length = input_length + 1
+                next_all_input_lengths.append(new_input_length)
+                next_batch_max_sequence_length = max(
+                    next_batch_max_sequence_length, new_input_length
+                )
+
+        # We finished all generations in the batch; there is no next batch
+        if not next_batch_keep_indices:
+            return generated_texts, None
+
+        # If we finished at least one generation
+        next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
+        if generated_texts:
+            # Apply indices to attention mask, past key values and other items that need to be cached
+            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
+                next_batch_keep_indices
+            ]
+            next_batch_input_ids["past_key_values"] = [
+                (
+                    keys[next_batch_past_keep_indices],
+                    values[next_batch_past_keep_indices],
+                )
+                for keys, values in outputs["past_key_values"]
+            ]
+            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
+            next_batch_next_token_choosers = [
+                batch.next_token_choosers[i] for i in next_batch_keep_indices
+            ]
+            next_batch_stopping_criterias = [
+                batch.stopping_criterias[i] for i in next_batch_keep_indices
+            ]
+        else:
+            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
+            next_batch_input_ids["past_key_values"] = outputs["past_key_values"]
+            next_batch_requests = batch.requests
+            next_batch_next_token_choosers = batch.next_token_choosers
+            next_batch_stopping_criterias = batch.stopping_criterias
+
+        # Update attention_mask with padding as we added a new token to input_ids
+        next_batch_input_ids["attention_mask"] = torch.cat(
+            [
+                next_batch_input_ids["attention_mask"],
+                torch.ones((next_batch_size, 1)).to(self.device),
+            ],
+            dim=1,
+        )
+
+        next_batch = Batch(
+            batch_id=batch.batch_id,
+            requests=next_batch_requests,
+            all_input_lengths=next_all_input_lengths,
+            input_ids=next_batch_input_ids,
+            all_input_ids=next_batch_all_input_ids,
+            next_token_choosers=next_batch_next_token_choosers,
+            stopping_criterias=next_batch_stopping_criterias,
+            size=next_batch_size,
+            max_sequence_length=next_batch_max_sequence_length,
+        )
+        return generated_texts, next_batch
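One detail worth spelling out: the cached `past_key_values` are laid out as `[batch_size * num_heads, ...]`, so dropping a finished request means dropping a whole stride of `num_heads` rows. A toy illustration of that index expansion (numbers invented):

```python
num_heads = 4
keep_requests = [0, 2]  # request 1 finished and is evicted from the batch

past_keep_indices = []
for i in keep_requests:
    # Request i owns rows [i * num_heads, (i + 1) * num_heads) of the flattened KV cache.
    past_keep_indices.extend(range(i * num_heads, (i + 1) * num_heads))

print(past_keep_indices)  # [0, 1, 2, 3, 8, 9, 10, 11]
```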
@@ -0,0 +1,206 @@
+import torch
+
+from dataclasses import dataclass
+from typing import List, Dict
+
+from transformers import AutoTokenizer
+
+from text_generation.pb import generate_pb2
+from text_generation.utils import NextTokenChooser, StoppingCriteria
+
+
+@dataclass
+class Batch:
+    batch_id: int
+    requests: List[generate_pb2.Request]
+    all_input_lengths: List[int]
+    input_ids: Dict[str, torch.Tensor]
+    all_input_ids: List[torch.Tensor]
+    next_token_choosers: List[NextTokenChooser]
+    stopping_criterias: List[StoppingCriteria]
+    size: int
+    max_sequence_length: int
+
+    def to_pb(self):
+        return generate_pb2.Batch(
+            id=self.batch_id,
+            requests=self.requests,
+            size=self.size,
+            max_sequence_length=self.max_sequence_length,
+        )
+
+    @classmethod
+    def from_pb(
+        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+    ) -> "Batch":
+        inputs = []
+        next_token_choosers = []
+        stopping_criterias = []
+        all_input_lengths = []
+
+        # Parse batch
+        for r in pb.requests:
+            inputs.append(r.inputs)
+            all_input_lengths.append(r.input_length)
+            next_token_choosers.append(
+                NextTokenChooser(
+                    temperature=r.parameters.temperature,
+                    top_k=r.parameters.top_k,
+                    top_p=r.parameters.top_p,
+                    do_sample=r.parameters.do_sample,
+                )
+            )
+            stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))
+
+        input_ids = tokenizer(
+            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
+        ).to(device)
+        all_input_ids = input_ids["input_ids"].unsqueeze(-1)
+
+        return cls(
+            batch_id=pb.id,
+            requests=pb.requests,
+            all_input_lengths=all_input_lengths,
+            input_ids=input_ids,
+            all_input_ids=all_input_ids,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            size=pb.size,
+            max_sequence_length=pb.max_sequence_length,
+        )
+
+    @classmethod
+    def concatenate(cls, batches: List["Batch"]) -> "Batch":
+        # Used for padding
+        total_batch_size = sum(batch.size for batch in batches)
+        max_sequence_length = max(batch.max_sequence_length for batch in batches)
+
+        # Batch attributes
+        input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
+        requests = []
+        all_input_lengths = []
+        all_input_ids = []
+        next_token_choosers = []
+        stopping_criterias = []
+
+        # Used for slicing correctly inside the tensors
+        # Equivalent to a cumsum on batch sizes
+        start_index = 0
+        for i, batch in enumerate(batches):
+            requests.extend(batch.requests)
+            all_input_lengths.extend(batch.all_input_lengths)
+            all_input_ids.extend(batch.all_input_ids)
+            next_token_choosers.extend(batch.next_token_choosers)
+            stopping_criterias.extend(batch.stopping_criterias)
+
+            # Slicing end index for this batch
+            end_index = start_index + batch.size
+
+            # We only concatenate batches that did at least one step
+            if batch.input_ids["input_ids"].shape[1] > 1:
+                raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
+
+            # Initialize tensors
+            if i == 0:
+                input_ids["input_ids"] = torch.empty(
+                    (total_batch_size, 1),
+                    dtype=batch.input_ids["input_ids"].dtype,
+                    device=batch.input_ids["input_ids"].device,
+                )
+                input_ids["attention_mask"] = torch.zeros(
+                    (total_batch_size, max_sequence_length),
+                    dtype=batch.input_ids["attention_mask"].dtype,
+                    device=batch.input_ids["attention_mask"].device,
+                )
+
+            # input_ids["input_ids"] is always of shape [batch_size, 1]
+            # We do not need to pad it
+            input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
+
+            # We need to slice the attention mask to remove padding from previous steps
+            input_ids["attention_mask"][
+                start_index:end_index, -batch.max_sequence_length :
+            ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]
+
+            for j, past in enumerate(batch.input_ids["past_key_values"]):
+                past_keys = past[0]
+                past_values = past[1]
+
+                _, head_dim, padded_sequence_length = past_keys.shape
+
+                # Reshape the tensors to make slicing easier
+                past_keys = past_keys.view(
+                    batch.size, -1, head_dim, padded_sequence_length
+                )
+                past_values = past_values.view(
+                    batch.size, -1, padded_sequence_length, head_dim
+                )
+                num_heads = past_keys.shape[1]
+
+                # Initialize tensors
+                # This will run only once per layer
+                if j == len(input_ids["past_key_values"]):
+                    padded_past_keys = torch.zeros(
+                        (
+                            total_batch_size,
+                            num_heads,
+                            head_dim,
+                            max_sequence_length - 1,
+                        ),
+                        dtype=past_keys.dtype,
+                        device=past_keys.device,
+                    )
+                    padded_past_values = torch.zeros(
+                        (
+                            total_batch_size,
+                            num_heads,
+                            max_sequence_length - 1,
+                            head_dim,
+                        ),
+                        dtype=past_values.dtype,
+                        device=past_values.device,
+                    )
+                    input_ids["past_key_values"].append(
+                        [padded_past_keys, padded_past_values]
+                    )
+
+                # We slice the past keys and values to remove the padding from previous batches
+                input_ids["past_key_values"][j][0][
+                    start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
+                ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
+
+                input_ids["past_key_values"][j][1][
+                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
+                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
+
+                # If we are on the last batch, we need to reshape the tensors
+                if (i + 1) == len(batches):
+                    input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
+                        j
+                    ][0].view(total_batch_size * num_heads, head_dim, -1)
+                    input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
+                        j
+                    ][1].view(total_batch_size * num_heads, -1, head_dim)
+
+            start_index += batch.size
+
+        return cls(
+            batch_id=batches[0].batch_id,
+            requests=requests,
+            all_input_lengths=all_input_lengths,
+            input_ids=input_ids,
+            all_input_ids=all_input_ids,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            size=total_batch_size,
+            max_sequence_length=max_sequence_length,
+        )
+
+
+@dataclass
+class GeneratedText:
+    request: generate_pb2.Request
+    output: str
+
+    def to_pb(self) -> generate_pb2.GeneratedText:
+        return generate_pb2.GeneratedText(request=self.request, output=self.output)
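The attention-mask handling in `concatenate` boils down to writing each batch's left-padded mask into the right-aligned slice of a merged mask. A small sketch with invented shapes:

```python
import torch

mask_a = torch.tensor([[0, 1, 1]])        # batch A: 1 request, padded to length 3
mask_b = torch.tensor([[1, 1, 1, 1, 1]])  # batch B: 1 request, padded to length 5

total_batch_size, max_sequence_length = 2, 5
merged = torch.zeros((total_batch_size, max_sequence_length), dtype=mask_a.dtype)

# Each batch only fills the last `batch.max_sequence_length` columns of its rows.
merged[0:1, -mask_a.shape[1] :] = mask_a[:, -mask_a.shape[1] :]
merged[1:2, -mask_b.shape[1] :] = mask_b[:, -mask_b.shape[1] :]

print(merged)
# tensor([[0, 0, 0, 1, 1],
#         [1, 1, 1, 1, 1]])
```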
@@ -5,15 +5,16 @@ from grpc import aio
 
 from grpc_reflection.v1alpha import reflection
 from pathlib import Path
-from typing import Optional, List
+from typing import List
 
-from bloom_inference.cache import Cache
-from bloom_inference.model import BLOOM, Batch, BLOOMSharded
-from bloom_inference.pb import generate_pb2_grpc, generate_pb2
+from text_generation.cache import Cache
+from text_generation.models import Model, get_model
+from text_generation.models.types import Batch
+from text_generation.pb import generate_pb2_grpc, generate_pb2
 
 
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
-    def __init__(self, model: BLOOM, cache: Cache, server_urls: List[str]):
+    def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
         self.cache = cache
         self.model = model
         self.server_urls = server_urls
@@ -78,21 +79,17 @@ def serve(
     ):
         unix_socket_template = "unix://{}-{}"
         if sharded:
-            model = BLOOMSharded(model_name, quantize)
             server_urls = [
                 unix_socket_template.format(uds_path, rank)
-                for rank in range(model.world_size)
+                for rank in range(int(os.environ["WORLD_SIZE"]))
             ]
-            local_url = server_urls[model.rank]
+            local_url = server_urls[int(os.environ["RANK"])]
         else:
-            if quantize:
-                raise ValueError(
-                    "bitsandbytes quantization is only available when running in `sharded` mode."
-                )
-            model = BLOOM(model_name)
             local_url = unix_socket_template.format(uds_path, 0)
             server_urls = [local_url]
 
+        model = get_model(model_name, sharded, quantize)
+
         server = aio.server()
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
             TextGenerationService(model, Cache(), server_urls), server
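With the model construction moved into `get_model`, the socket layout now comes from the launcher's environment. A standalone illustration with made-up values (`WORLD_SIZE` and `RANK` are the variables the diff reads; the launcher is expected to set them):

```python
import os

os.environ.setdefault("WORLD_SIZE", "2")
os.environ.setdefault("RANK", "0")

uds_path = "/tmp/text-generation"
unix_socket_template = "unix://{}-{}"

server_urls = [
    unix_socket_template.format(uds_path, rank)
    for rank in range(int(os.environ["WORLD_SIZE"]))
]
local_url = server_urls[int(os.environ["RANK"])]

print(server_urls)  # ['unix:///tmp/text-generation-0', 'unix:///tmp/text-generation-1']
print(local_url)    # unix:///tmp/text-generation-0
```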
@@ -92,19 +92,17 @@ def initialize_torch_distributed():
     return torch.distributed.distributed_c10d._get_default_group(), rank, world_size
 
 
-def weight_hub_files(model_name):
+def weight_hub_files(model_name, extension=".safetensors"):
     """Get the safetensors filenames on the hub"""
     api = HfApi()
     info = api.model_info(model_name)
-    filenames = [
-        s.rfilename for s in info.siblings if s.rfilename.endswith(".safetensors")
-    ]
+    filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
     return filenames
 
 
-def weight_files(model_name):
+def weight_files(model_name, extension=".safetensors"):
     """Get the local safetensors filenames"""
-    filenames = weight_hub_files(model_name)
+    filenames = weight_hub_files(model_name, extension)
     files = []
     for filename in filenames:
         cache_file = try_to_load_from_cache(model_name, filename=filename)
@@ -112,16 +110,16 @@ def weight_files(model_name):
             raise LocalEntryNotFoundError(
                 f"File {filename} of model {model_name} not found in "
                 f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
-                f"Please run `bloom-inference-server download-weights {model_name}` first."
+                f"Please run `text-generation-server download-weights {model_name}` first."
             )
         files.append(cache_file)
 
     return files
 
 
-def download_weights(model_name):
+def download_weights(model_name, extension=".safetensors"):
     """Download the safetensors files from the hub"""
-    filenames = weight_hub_files(model_name)
+    filenames = weight_hub_files(model_name, extension)
 
     download_function = partial(
         hf_hub_download, repo_id=model_name, local_files_only=False
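The new `extension` parameter simply threads a suffix filter through the hub helpers. A toy version of that filter (the file listing below is invented, standing in for `api.model_info(model_name).siblings`):

```python
siblings = [
    "config.json",
    "model_00001-of-00002.safetensors",
    "model_00002-of-00002.safetensors",
    "tokenizer.json",
]

extension = ".safetensors"
filenames = [s for s in siblings if s.endswith(extension)]
print(filenames)  # ['model_00001-of-00002.safetensors', 'model_00002-of-00002.safetensors']
```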