feat(server): Support all AutoModelForCausalLM on a best effort basis
This commit is contained in:
parent
09674e6df9
commit
3cf6368c77
|
@ -28,9 +28,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.65"
|
||||
version = "1.0.66"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "98161a4e3e2184da77bb14f02184cdd111e83bbbcc9979dfee3c44b9a85f5602"
|
||||
checksum = "216261ddc8289130e551ddcd5ce8a064710c0d064a4d2895c67151c92b5443f6"
|
||||
|
||||
[[package]]
|
||||
name = "async-stream"
|
||||
|
@ -83,9 +83,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
|||
|
||||
[[package]]
|
||||
name = "axum"
|
||||
version = "0.5.16"
|
||||
version = "0.5.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c9e3356844c4d6a6d6467b8da2cffb4a2820be256f50a3a386c9d152bab31043"
|
||||
checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core",
|
||||
|
@ -114,9 +114,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "axum-core"
|
||||
version = "0.2.8"
|
||||
version = "0.2.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9f0c0a60006f2a293d82d571f635042a72edf927539b7685bd62d361963839b"
|
||||
checksum = "37e5939e02c56fecd5c017c37df4238c0a839fa76b7f97acdd7efb804fd181cc"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
|
@ -130,9 +130,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.13.0"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd"
|
||||
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
|
@ -149,21 +149,6 @@ dependencies = [
|
|||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bloom-inference-client"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"futures",
|
||||
"prost",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tonic",
|
||||
"tonic-build",
|
||||
"tower",
|
||||
"tracing",
|
||||
"tracing-error",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "3.11.1"
|
||||
|
@ -255,9 +240,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.0.17"
|
||||
version = "4.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06badb543e734a2d6568e19a40af66ed5364360b9226184926f89d229b4b4267"
|
||||
checksum = "335867764ed2de42325fafe6d18b8af74ba97ee0c590fa016f157535b42ab04b"
|
||||
dependencies = [
|
||||
"atty",
|
||||
"bitflags",
|
||||
|
@ -270,9 +255,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "clap_derive"
|
||||
version = "4.0.13"
|
||||
version = "4.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c42f169caba89a7d512b5418b09864543eeb4d497416c917d7137863bd2076ad"
|
||||
checksum = "16a1b0f6422af32d5da0c58e2703320f379216ee70198241c84173a8c5ac28f3"
|
||||
dependencies = [
|
||||
"heck 0.4.0",
|
||||
"proc-macro-error",
|
||||
|
@ -532,14 +517,14 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "filetime"
|
||||
version = "0.2.17"
|
||||
version = "0.2.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e94a7bbaa59354bc20dd75b67f23e2797b4490e9d6928203fb105c79e448c86c"
|
||||
checksum = "4b9663d381d07ae25dc88dbdf27df458faa83a9b25336bcac83d5e452b5fc9d3"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"redox_syscall",
|
||||
"windows-sys 0.36.1",
|
||||
"windows-sys 0.42.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -600,9 +585,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.24"
|
||||
version = "0.3.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f21eda599937fba36daeb58a22e8f5cee2d14c4a17b5b7739c7c8e5e3b8230c"
|
||||
checksum = "38390104763dc37a5145a53c29c63c1290b5d316d6086ec32c293f6736051bb0"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
|
@ -615,9 +600,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "futures-channel"
|
||||
version = "0.3.24"
|
||||
version = "0.3.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "30bdd20c28fadd505d0fd6712cdfcb0d4b5648baf45faef7f852afb2399bb050"
|
||||
checksum = "52ba265a92256105f45b719605a571ffe2d1f0fea3807304b522c1d778f79eed"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
|
@ -625,15 +610,15 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "futures-core"
|
||||
version = "0.3.24"
|
||||
version = "0.3.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4e5aa3de05362c3fb88de6531e6296e85cde7739cccad4b9dfeeb7f6ebce56bf"
|
||||
checksum = "04909a7a7e4633ae6c4a9ab280aeb86da1236243a77b694a49eacd659a4bd3ac"
|
||||
|
||||
[[package]]
|
||||
name = "futures-executor"
|
||||
version = "0.3.24"
|
||||
version = "0.3.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ff63c23854bee61b6e9cd331d523909f238fc7636290b96826e9cfa5faa00ab"
|
||||
checksum = "7acc85df6714c176ab5edf386123fafe217be88c0840ec11f199441134a074e2"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-task",
|
||||
|
@ -642,15 +627,15 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "futures-io"
|
||||
version = "0.3.24"
|
||||
version = "0.3.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbf4d2a7a308fd4578637c0b17c7e1c7ba127b8f6ba00b29f717e9655d85eb68"
|
||||
checksum = "00f5fb52a06bdcadeb54e8d3671f8888a39697dcb0b81b23b55174030427f4eb"
|
||||
|
||||
[[package]]
|
||||
name = "futures-macro"
|
||||
version = "0.3.24"
|
||||
version = "0.3.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42cd15d1c7456c04dbdf7e88bcd69760d74f3a798d6444e16974b505b0e62f17"
|
||||
checksum = "bdfb8ce053d86b91919aad980c220b1fb8401a9394410e1c289ed7e66b61835d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
@ -659,21 +644,21 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.24"
|
||||
version = "0.3.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "21b20ba5a92e727ba30e72834706623d94ac93a725410b6a6b6fbc1b07f7ba56"
|
||||
checksum = "39c15cf1a4aa79df40f1bb462fb39676d0ad9e366c2a33b590d7c66f4f81fcf9"
|
||||
|
||||
[[package]]
|
||||
name = "futures-task"
|
||||
version = "0.3.24"
|
||||
version = "0.3.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a6508c467c73851293f390476d4491cf4d227dbabcd4170f3bb6044959b294f1"
|
||||
checksum = "2ffb393ac5d9a6eaa9d3fdf37ae2776656b706e200c8e16b1bdb227f5198e6ea"
|
||||
|
||||
[[package]]
|
||||
name = "futures-util"
|
||||
version = "0.3.24"
|
||||
version = "0.3.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "44fb6cb1be61cc1d2e43b262516aafcf63b241cffdb1d3fa115f91d9c7b09c90"
|
||||
checksum = "197676987abd2f9cadff84926f410af1c183608d36641465df73ae8211dc65d6"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
|
@ -699,9 +684,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.7"
|
||||
version = "0.2.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4eb1a864a501629691edf6c15a593b7a51eebaa1e8468e9ddc623de7c9b58ec6"
|
||||
checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
|
@ -716,9 +701,9 @@ checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"
|
|||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.3.14"
|
||||
version = "0.3.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5ca32592cf21ac7ccab1825cd87f6c9b3d9022c44d086172ed0966bec8af30be"
|
||||
checksum = "5f9f29bc9dda355256b2916cf526ab02ce0aeaaaf2bad60d65ef3f12f11dd0f4"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fnv",
|
||||
|
@ -967,9 +952,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
|||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.135"
|
||||
version = "0.2.137"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "68783febc7782c6c5cb401fbda4de5a9898be1762314da0bb2c10ced61f18b0c"
|
||||
checksum = "fc7fcc620a3bff7cdd7a365be3376c97191aeaccc2a603e600951e452615bf89"
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
|
@ -992,9 +977,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "macro_rules_attribute"
|
||||
version = "0.1.2"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "258c86475e1616d6f2d8f5227cfaabd3dae1f6d5388b9597df8a199d4497aba7"
|
||||
checksum = "cf0c9b980bf4f3a37fd7b1c066941dd1b1d0152ce6ee6e8fe8c49b9f6810d862"
|
||||
dependencies = [
|
||||
"macro_rules_attribute-proc_macro",
|
||||
"paste",
|
||||
|
@ -1002,9 +987,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "macro_rules_attribute-proc_macro"
|
||||
version = "0.1.2"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f26a8d2502d5aa4d411ef494ba7470eb299f05725179ce3b5de77aa01a9ffdea"
|
||||
checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"
|
||||
|
||||
[[package]]
|
||||
name = "matchit"
|
||||
|
@ -1050,14 +1035,14 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "mio"
|
||||
version = "0.8.4"
|
||||
version = "0.8.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57ee1c23c7c63b0c9250c339ffdc69255f110b298b901b9f6c82547b7b87caaf"
|
||||
checksum = "e5d732bc30207a6423068df043e3d02e0735b155ad7ce1a6f76fe2baa5b158de"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"log",
|
||||
"wasi 0.11.0+wasi-snapshot-preview1",
|
||||
"windows-sys 0.36.1",
|
||||
"windows-sys 0.42.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1200,9 +1185,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
|
|||
|
||||
[[package]]
|
||||
name = "openssl-sys"
|
||||
version = "0.9.76"
|
||||
version = "0.9.77"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5230151e44c0f05157effb743e8d517472843121cf9243e8b81393edb5acd9ce"
|
||||
checksum = "b03b84c3b2d099b81f0953422b4d4ad58761589d0229b5506356afca05a3670a"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"cc",
|
||||
|
@ -1213,9 +1198,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "os_str_bytes"
|
||||
version = "6.3.0"
|
||||
version = "6.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff"
|
||||
checksum = "3baf96e39c5359d2eb0dd6ccb42c62b91d9678aa68160d261b9e0ccbf9e9dea9"
|
||||
|
||||
[[package]]
|
||||
name = "overload"
|
||||
|
@ -1302,9 +1287,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
|
|||
|
||||
[[package]]
|
||||
name = "pkg-config"
|
||||
version = "0.3.25"
|
||||
version = "0.3.26"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae"
|
||||
checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
|
@ -1602,18 +1587,18 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.145"
|
||||
version = "1.0.147"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "728eb6351430bccb993660dfffc5a72f91ccc1295abaa8ce19b27ebe4f75568b"
|
||||
checksum = "d193d69bae983fc11a79df82342761dfbf28a99fc8d203dca4c3c1b590948965"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.145"
|
||||
version = "1.0.147"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "81fa1584d3d1bcacd84c277a0dfe21f5b0f6accf4a23d04d4c6d61f1af522b4c"
|
||||
checksum = "4f1d362ca8fc9c3e3a7484440752472d68a6caa98f1ab81d99b5dfe517cec852"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
@ -1622,9 +1607,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.86"
|
||||
version = "1.0.87"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41feea4228a6f1cd09ec7a3593a682276702cd67b5273544757dae23c096f074"
|
||||
checksum = "6ce777b7b150d76b9cf60d28b55f5847135a003f7d7350c6be7a773508ce7d45"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
|
@ -1739,9 +1724,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.102"
|
||||
version = "1.0.103"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fcd952facd492f9be3ef0d0b7032a6e442ee9b361d4acc2b1d0c4aaa5f613a1"
|
||||
checksum = "a864042229133ada95abf3b54fdc62ef5ccabe9515b64717bcb9a1919e59445d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
@ -1798,11 +1783,26 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-client"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"futures",
|
||||
"prost",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tonic",
|
||||
"tonic-build",
|
||||
"tower",
|
||||
"tracing",
|
||||
"tracing-error",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-launcher"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"clap 4.0.17",
|
||||
"clap 4.0.18",
|
||||
"ctrlc",
|
||||
"subprocess",
|
||||
"tracing",
|
||||
|
@ -1814,12 +1814,12 @@ name = "text-generation-router"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"bloom-inference-client",
|
||||
"clap 4.0.17",
|
||||
"clap 4.0.18",
|
||||
"futures",
|
||||
"parking_lot",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"text-generation-client",
|
||||
"thiserror",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
|
|
|
@ -66,7 +66,7 @@ COPY proto proto
|
|||
COPY server server
|
||||
RUN cd server && \
|
||||
make gen-server && \
|
||||
/opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir
|
||||
/opt/miniconda/envs/text-generation/bin/pip install ".[bnb]" --no-cache-dir
|
||||
|
||||
# Install router
|
||||
COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
|
||||
|
|
2
Makefile
2
Makefile
|
@ -22,7 +22,7 @@ run-bloom-560m-quantize:
|
|||
text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2 --quantize
|
||||
|
||||
download-bloom:
|
||||
bloom-inference-server download-weights bigscience/bloom
|
||||
text-generation-server download-weights bigscience/bloom
|
||||
|
||||
run-bloom:
|
||||
text-generation-launcher --model-name bigscience/bloom --num-shard 8
|
||||
|
|
|
@ -15,11 +15,13 @@ A Rust and gRPC server for large language models text generation inference.
|
|||
- [Safetensors](https://github.com/huggingface/safetensors) weight loading
|
||||
- 45ms per token generation for BLOOM with 8xA100 80GB
|
||||
|
||||
## Supported models
|
||||
## Officially supported models
|
||||
|
||||
- BLOOM
|
||||
- BLOOM-560m
|
||||
|
||||
Other models are supported on a best-effort basis using `AutoModelForCausalLM.from_pretrained(<model>, torch_dtype=torch.float16, device_map="auto")`.
|
||||
|
||||
## Load Tests for BLOOM
|
||||
|
||||
See `k6/load_test.js`
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
$schema: https://azuremlschemas.azureedge.net/latest/model.schema.json
|
||||
name: bloom
|
||||
name: bloom-safetensors
|
||||
version: 1
|
||||
path: ./bloom
|
||||
path: ./bloom-safetensors
|
||||
type: custom_model
|
||||
|
|
|
@ -256,7 +256,7 @@ fn shard_manager(
|
|||
|
||||
// Process args
|
||||
let mut shard_argv = vec![
|
||||
"bloom-inference-server".to_string(),
|
||||
"text-generation-server".to_string(),
|
||||
"serve".to_string(),
|
||||
model_name,
|
||||
"--uds-path".to_string(),
|
||||
|
@ -311,7 +311,7 @@ fn shard_manager(
|
|||
Err(err) => {
|
||||
if let PopenError::IoError(ref err) = err {
|
||||
if err.kind() == io::ErrorKind::NotFound {
|
||||
tracing::error!("bloom-inference-server not found in PATH");
|
||||
tracing::error!("text-generation-server not found in PATH");
|
||||
tracing::error!("Please install it with `make install-server`")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ path = "src/main.rs"
|
|||
|
||||
[dependencies]
|
||||
axum = { version = "0.5.16", features = ["json", "serde_json"] }
|
||||
bloom-inference-client = { path = "client" }
|
||||
text-generation-client = { path = "client" }
|
||||
clap = { version = "4.0.15", features = ["derive", "env"] }
|
||||
futures = "0.3.24"
|
||||
parking_lot = "0.12.1"
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[package]
|
||||
name = "bloom-inference-client"
|
||||
name = "text-generation-client"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
|
|
|
@ -3,9 +3,9 @@ use crate::{Db, Entry};
|
|||
use crate::{ErrorResponse, GenerateRequest};
|
||||
use axum::http::StatusCode;
|
||||
use axum::Json;
|
||||
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{oneshot, Notify};
|
||||
use tokio::time::Instant;
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
use crate::InferResponse;
|
||||
/// This code is massively inspired by Tokio mini-redis
|
||||
use crate::{GenerateParameters, GenerateRequest};
|
||||
use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
|
||||
use parking_lot::Mutex;
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{Batch, ClientError, LogitsWarperParameters, Request};
|
||||
use tokio::sync::oneshot::Sender;
|
||||
use tokio::time::Instant;
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
/// Text Generation Inference webserver entrypoint
|
||||
use bloom_inference_client::ShardedClient;
|
||||
use clap::Parser;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
/// Text Generation Inference webserver entrypoint
|
||||
use text_generation_client::ShardedClient;
|
||||
use text_generation_router::server;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
|
@ -19,7 +19,7 @@ struct Args {
|
|||
max_waiting_tokens: usize,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
port: u16,
|
||||
#[clap(default_value = "/tmp/bloom-inference-0", long, env)]
|
||||
#[clap(default_value = "/tmp/text-generation-0", long, env)]
|
||||
master_shard_uds_path: String,
|
||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||
tokenizer_name: String,
|
||||
|
|
|
@ -6,9 +6,9 @@ use axum::http::{HeaderMap, StatusCode};
|
|||
use axum::response::IntoResponse;
|
||||
use axum::routing::{get, post};
|
||||
use axum::{Json, Router};
|
||||
use bloom_inference_client::ShardedClient;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::ShardedClient;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::signal;
|
||||
use tokio::sync::Semaphore;
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
bloom_inference/__pycache__/
|
||||
bloom_inference/pb/__pycache__/
|
||||
text_generation/__pycache__/
|
||||
text_generation/pb/__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
gen-server:
|
||||
# Compile protos
|
||||
pip install grpcio-tools==1.49.1 --no-cache-dir
|
||||
mkdir bloom_inference/pb || true
|
||||
python -m grpc_tools.protoc -I../proto --python_out=bloom_inference/pb --grpc_python_out=bloom_inference/pb ../proto/generate.proto
|
||||
find bloom_inference/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
|
||||
touch bloom_inference/pb/__init__.py
|
||||
mkdir text_generation/pb || true
|
||||
python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
|
||||
find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
|
||||
touch text_generation/pb/__init__.py
|
||||
|
||||
install-transformers:
|
||||
# Install specific version of transformers
|
||||
|
@ -36,4 +36,4 @@ install: gen-server install-torch install-transformers install-safetensors
|
|||
pip install -e . --no-cache-dir
|
||||
|
||||
run-dev:
|
||||
python -m torch.distributed.run --nproc_per_node=2 bloom_inference/cli.py serve bigscience/bloom-560m --sharded
|
||||
python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
|
|
@ -1,582 +0,0 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple, Optional, Dict
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
from safetensors import safe_open
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
||||
from transformers.models.bloom.parallel_layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
from bloom_inference.pb import generate_pb2
|
||||
from bloom_inference.utils import (
|
||||
StoppingCriteria,
|
||||
NextTokenChooser,
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
download_weights,
|
||||
)
|
||||
|
||||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.nn import Int8Params
|
||||
except Exception as e:
|
||||
HAS_BITS_AND_BYTES = False
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Batch:
|
||||
batch_id: int
|
||||
requests: List[generate_pb2.Request]
|
||||
all_input_lengths: List[int]
|
||||
input_ids: Dict[str, torch.Tensor]
|
||||
all_input_ids: List[torch.Tensor]
|
||||
next_token_choosers: List[NextTokenChooser]
|
||||
stopping_criterias: List[StoppingCriteria]
|
||||
size: int
|
||||
max_sequence_length: int
|
||||
|
||||
def to_pb(self):
|
||||
return generate_pb2.Batch(
|
||||
id=self.batch_id,
|
||||
requests=self.requests,
|
||||
size=self.size,
|
||||
max_sequence_length=self.max_sequence_length,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
|
||||
) -> "Batch":
|
||||
inputs = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
all_input_lengths = []
|
||||
|
||||
# Parse batch
|
||||
for r in pb.requests:
|
||||
inputs.append(r.inputs)
|
||||
all_input_lengths.append(r.input_length)
|
||||
next_token_choosers.append(
|
||||
NextTokenChooser(
|
||||
temperature=r.parameters.temperature,
|
||||
top_k=r.parameters.top_k,
|
||||
top_p=r.parameters.top_p,
|
||||
do_sample=r.parameters.do_sample,
|
||||
)
|
||||
)
|
||||
stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))
|
||||
|
||||
input_ids = tokenizer(
|
||||
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
|
||||
).to(device)
|
||||
all_input_ids = input_ids["input_ids"].unsqueeze(-1)
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
requests=pb.requests,
|
||||
all_input_lengths=all_input_lengths,
|
||||
input_ids=input_ids,
|
||||
all_input_ids=all_input_ids,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
size=pb.size,
|
||||
max_sequence_length=pb.max_sequence_length,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def concatenate(cls, batches: List["Batch"]) -> "Batch":
|
||||
# Used for padding
|
||||
total_batch_size = sum(batch.size for batch in batches)
|
||||
max_sequence_length = max(batch.max_sequence_length for batch in batches)
|
||||
|
||||
# Batch attributes
|
||||
input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
|
||||
requests = []
|
||||
all_input_lengths = []
|
||||
all_input_ids = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
|
||||
# Used for slicing correctly inside the tensors
|
||||
# Equivalent to a cumsum on batch sizes
|
||||
start_index = 0
|
||||
for i, batch in enumerate(batches):
|
||||
requests.extend(batch.requests)
|
||||
all_input_lengths.extend(batch.all_input_lengths)
|
||||
all_input_ids.extend(batch.all_input_ids)
|
||||
next_token_choosers.extend(batch.next_token_choosers)
|
||||
stopping_criterias.extend(batch.stopping_criterias)
|
||||
|
||||
# Slicing end index for this batch
|
||||
end_index = start_index + batch.size
|
||||
|
||||
# We only concatenate batches that did at least one step
|
||||
if batch.input_ids["input_ids"].shape[1] > 1:
|
||||
raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
|
||||
|
||||
# Initialize tensors
|
||||
if i == 0:
|
||||
input_ids["input_ids"] = torch.empty(
|
||||
(total_batch_size, 1),
|
||||
dtype=batch.input_ids["input_ids"].dtype,
|
||||
device=batch.input_ids["input_ids"].device,
|
||||
)
|
||||
input_ids["attention_mask"] = torch.zeros(
|
||||
(total_batch_size, max_sequence_length),
|
||||
dtype=batch.input_ids["attention_mask"].dtype,
|
||||
device=batch.input_ids["attention_mask"].device,
|
||||
)
|
||||
|
||||
# input_ids["input_ids"] is always of shape [batch_size, 1]
|
||||
# We do not need to pad it
|
||||
input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
|
||||
|
||||
# We need to slice the attention mask to remove padding from previous steps
|
||||
input_ids["attention_mask"][
|
||||
start_index:end_index, -batch.max_sequence_length :
|
||||
] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]
|
||||
|
||||
for j, past in enumerate(batch.input_ids["past_key_values"]):
|
||||
past_keys = past[0]
|
||||
past_values = past[1]
|
||||
|
||||
_, head_dim, padded_sequence_length = past_keys.shape
|
||||
|
||||
# Reshape the tensors to make slicing easier
|
||||
past_keys = past_keys.view(
|
||||
batch.size, -1, head_dim, padded_sequence_length
|
||||
)
|
||||
past_values = past_values.view(
|
||||
batch.size, -1, padded_sequence_length, head_dim
|
||||
)
|
||||
num_heads = past_keys.shape[1]
|
||||
|
||||
# Initialize tensors
|
||||
# This will run only once per layer
|
||||
if j == len(input_ids["past_key_values"]):
|
||||
padded_past_keys = torch.zeros(
|
||||
(
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
head_dim,
|
||||
max_sequence_length - 1,
|
||||
),
|
||||
dtype=past_keys.dtype,
|
||||
device=past_keys.device,
|
||||
)
|
||||
padded_past_values = torch.zeros(
|
||||
(
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
max_sequence_length - 1,
|
||||
head_dim,
|
||||
),
|
||||
dtype=past_values.dtype,
|
||||
device=past_values.device,
|
||||
)
|
||||
input_ids["past_key_values"].append(
|
||||
[padded_past_keys, padded_past_values]
|
||||
)
|
||||
|
||||
# We slice the past keys and values to remove the padding from previous batches
|
||||
input_ids["past_key_values"][j][0][
|
||||
start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
|
||||
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
|
||||
|
||||
input_ids["past_key_values"][j][1][
|
||||
start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
|
||||
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
|
||||
|
||||
# If we are on the last batch, we need to reshape the tensors
|
||||
if (i + 1) == len(batches):
|
||||
input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
|
||||
j
|
||||
][0].view(total_batch_size * num_heads, head_dim, -1)
|
||||
input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
|
||||
j
|
||||
][1].view(total_batch_size * num_heads, -1, head_dim)
|
||||
|
||||
start_index += batch.size
|
||||
|
||||
return cls(
|
||||
batch_id=batches[0].batch_id,
|
||||
requests=requests,
|
||||
all_input_lengths=all_input_lengths,
|
||||
input_ids=input_ids,
|
||||
all_input_ids=all_input_ids,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
size=total_batch_size,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratedText:
|
||||
request: generate_pb2.Request
|
||||
output: str
|
||||
|
||||
def to_pb(self) -> generate_pb2.GeneratedText:
|
||||
return generate_pb2.GeneratedText(request=self.request, output=self.output)
|
||||
|
||||
|
||||
class BLOOM:
|
||||
def __init__(self, model_name: str):
|
||||
if torch.cuda.is_available():
|
||||
self.device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
|
||||
self.model = (
|
||||
AutoModelForCausalLM.from_pretrained(model_name)
|
||||
.eval()
|
||||
.to(self.device)
|
||||
.to(dtype)
|
||||
)
|
||||
self.num_heads = self.model.base_model.num_heads
|
||||
|
||||
def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
|
||||
# Model Forward
|
||||
return self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
def generate_token(
|
||||
self, batch: Batch
|
||||
) -> Tuple[List[GeneratedText], Optional[Batch]]:
|
||||
with torch.inference_mode():
|
||||
outputs = self.forward(**batch.input_ids)
|
||||
|
||||
# List of indices to cache
|
||||
next_batch_keep_indices = []
|
||||
next_batch_past_keep_indices = []
|
||||
|
||||
# New input_ids for next forward
|
||||
next_batch_input_ids = []
|
||||
next_batch_all_input_ids = []
|
||||
next_all_input_lengths = []
|
||||
|
||||
next_batch_size = 0
|
||||
next_batch_max_sequence_length = 0
|
||||
|
||||
# Finished requests
|
||||
generated_texts: List[GeneratedText] = []
|
||||
|
||||
# Zipped iterator
|
||||
iterator = zip(
|
||||
batch.requests,
|
||||
batch.all_input_lengths,
|
||||
outputs.logits,
|
||||
batch.next_token_choosers,
|
||||
batch.stopping_criterias,
|
||||
batch.all_input_ids,
|
||||
)
|
||||
|
||||
# For each member of the batch
|
||||
for i, (
|
||||
request,
|
||||
input_length,
|
||||
logits,
|
||||
next_token_chooser,
|
||||
stopping_criteria,
|
||||
all_tokens,
|
||||
) in enumerate(iterator):
|
||||
# Select next token
|
||||
next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
|
||||
|
||||
# Append next token to all tokens
|
||||
all_tokens = torch.cat([all_tokens, next_token])
|
||||
|
||||
# Evaluate stopping criteria
|
||||
if stopping_criteria(all_tokens):
|
||||
# Decode all tokens
|
||||
output = self.tokenizer.decode(
|
||||
all_tokens.squeeze(-1), skip_special_tokens=True
|
||||
)
|
||||
# Add to the list of finished generations with the original request
|
||||
generated_texts.append(GeneratedText(request, output))
|
||||
# add to the next batch
|
||||
else:
|
||||
next_batch_keep_indices.append(i)
|
||||
# past_key_values is of shape [batch_size * num_heads, ...]
|
||||
# so we need to take into account the `num_heads` stride here
|
||||
next_batch_past_keep_indices.extend(
|
||||
[j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
|
||||
)
|
||||
next_batch_input_ids.append(next_token)
|
||||
next_batch_all_input_ids.append(all_tokens)
|
||||
next_batch_size += 1
|
||||
new_input_length = input_length + 1
|
||||
next_all_input_lengths.append(new_input_length)
|
||||
next_batch_max_sequence_length = max(
|
||||
next_batch_max_sequence_length, new_input_length
|
||||
)
|
||||
|
||||
# We finished all generations in the batch; there is no next batch
|
||||
if not next_batch_keep_indices:
|
||||
return generated_texts, None
|
||||
|
||||
# If we finished at least one generation
|
||||
next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
|
||||
if generated_texts:
|
||||
# Apply indices to attention mask, past key values and other items that need to be cached
|
||||
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
|
||||
next_batch_keep_indices
|
||||
]
|
||||
next_batch_input_ids["past_key_values"] = [
|
||||
(
|
||||
keys[next_batch_past_keep_indices],
|
||||
values[next_batch_past_keep_indices],
|
||||
)
|
||||
for keys, values in outputs["past_key_values"]
|
||||
]
|
||||
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
|
||||
next_batch_next_token_choosers = [
|
||||
batch.next_token_choosers[i] for i in next_batch_keep_indices
|
||||
]
|
||||
next_batch_stopping_criterias = [
|
||||
batch.stopping_criterias[i] for i in next_batch_keep_indices
|
||||
]
|
||||
else:
|
||||
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
|
||||
next_batch_input_ids["past_key_values"] = outputs["past_key_values"]
|
||||
next_batch_requests = batch.requests
|
||||
next_batch_next_token_choosers = batch.next_token_choosers
|
||||
next_batch_stopping_criterias = batch.stopping_criterias
|
||||
|
||||
# Update attention_mask with padding as we added a new token to input_ids
|
||||
next_batch_input_ids["attention_mask"] = torch.cat(
|
||||
[
|
||||
next_batch_input_ids["attention_mask"],
|
||||
torch.ones((next_batch_size, 1)).to(self.device),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
next_batch = Batch(
|
||||
batch_id=batch.batch_id,
|
||||
requests=next_batch_requests,
|
||||
all_input_lengths=next_all_input_lengths,
|
||||
input_ids=next_batch_input_ids,
|
||||
all_input_ids=next_batch_all_input_ids,
|
||||
next_token_choosers=next_batch_next_token_choosers,
|
||||
stopping_criterias=next_batch_stopping_criterias,
|
||||
size=next_batch_size,
|
||||
max_sequence_length=next_batch_max_sequence_length,
|
||||
)
|
||||
return generated_texts, next_batch
|
||||
|
||||
|
||||
class BLOOMSharded(BLOOM):
|
||||
def __init__(self, model_name: str, quantize: bool = False):
|
||||
super(BLOOM, self).__init__()
|
||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||
self.master = self.rank == 0
|
||||
if torch.cuda.is_available():
|
||||
self.device = torch.device(f"cuda:{self.rank}")
|
||||
dtype = torch.float16
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_name, slow_but_exact=False, tp_parallel=True
|
||||
)
|
||||
config.pad_token_id = 3
|
||||
self.num_heads = config.n_head // self.process_group.size()
|
||||
|
||||
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
|
||||
# in PyTorch 1.12 and later.
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# Only download weights for small models
|
||||
if self.master and model_name == "bigscience/bloom-560m":
|
||||
download_weights(model_name)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_name)
|
||||
|
||||
with init_empty_weights():
|
||||
model = AutoModelForCausalLM.from_config(config)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
quantize=quantize,
|
||||
device=self.device,
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
)
|
||||
self.model = model.eval().to(dtype)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: bool,
|
||||
device: torch.device,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
):
|
||||
parameters = dict(model.named_parameters())
|
||||
for file in filenames:
|
||||
with safe_open(
|
||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
||||
) as f:
|
||||
for name in f.keys():
|
||||
full_name = f"transformer.{name}"
|
||||
|
||||
module_name, param_name = full_name.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
current_tensor = parameters[full_name]
|
||||
|
||||
slice_ = f.get_slice(name)
|
||||
|
||||
if isinstance(module, TensorParallelColumnLinear):
|
||||
if param_name == "weight":
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
tensor = tensor.transpose(1, 0)
|
||||
else:
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif isinstance(module, TensorParallelRowLinear):
|
||||
if param_name == "weight":
|
||||
size = slice_.get_shape()[1]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[:, start:stop]
|
||||
tensor = tensor.transpose(1, 0)
|
||||
else:
|
||||
tensor = slice_[:]
|
||||
# XXX: Hack for Rowlinear to add the bias only once.
|
||||
if rank != 0:
|
||||
tensor = torch.zeros_like(tensor)
|
||||
elif isinstance(module, TensorParallelEmbedding):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
else:
|
||||
tensor = slice_[:]
|
||||
|
||||
if current_tensor.shape != tensor.shape:
|
||||
raise ValueError(
|
||||
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
|
||||
)
|
||||
|
||||
tensor = tensor.contiguous()
|
||||
|
||||
if quantize:
|
||||
if not HAS_BITS_AND_BYTES:
|
||||
raise ImportError(
|
||||
"bitsandbytes is not available on your machine"
|
||||
)
|
||||
|
||||
if (
|
||||
type(module)
|
||||
in [TensorParallelRowLinear, TensorParallelColumnLinear]
|
||||
and param_name == "weight"
|
||||
):
|
||||
tensor = Int8Params(
|
||||
tensor.transpose(1, 0),
|
||||
has_fp16_weights=False,
|
||||
requires_grad=False,
|
||||
).to(device)
|
||||
state = bnb.MatmulLtState()
|
||||
state.threshold = 6.0
|
||||
state.has_fp16_weights = False
|
||||
state.memory_efficient_backward = False
|
||||
state.use_pool = True
|
||||
state.CB = tensor.CB
|
||||
state.SCB = tensor.SCB
|
||||
tensor.CB = None
|
||||
tensor.SCB = None
|
||||
|
||||
def replace_linear(state, in_features, out_features):
|
||||
def linear(input, weight, bias):
|
||||
size_out = input.size()[:-1] + (out_features,)
|
||||
input = input.view(-1, in_features)
|
||||
out = torch.empty(
|
||||
size_out, device=input.device, dtype=input.dtype
|
||||
)
|
||||
out = bnb.matmul(
|
||||
input,
|
||||
weight,
|
||||
out=out.view(-1, out_features),
|
||||
state=state,
|
||||
threshold=state.threshold,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
if state.CB is not None:
|
||||
# we converted 8-bit row major to turing/ampere format
|
||||
# in the first inference pass
|
||||
# we no longer need the row-major weight
|
||||
del state.CB
|
||||
weight.data = state.CxB
|
||||
|
||||
return out.view(size_out)
|
||||
|
||||
return linear
|
||||
|
||||
module.linear = replace_linear(
|
||||
state, module.in_features, module.out_features
|
||||
)
|
||||
|
||||
else:
|
||||
tensor = tensor.to(device)
|
||||
|
||||
module._parameters[param_name] = tensor
|
||||
if name == "word_embeddings.weight":
|
||||
model.lm_head._parameters["weight"] = tensor
|
||||
|
||||
def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
|
||||
outputs = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
# Logits are sharded, so we need to gather them
|
||||
logits_shard = outputs.logits[:, -1, :].contiguous()
|
||||
|
||||
batch_size, vocab_shard_size = logits_shard.shape
|
||||
vocab_size = self.world_size * vocab_shard_size
|
||||
logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
|
||||
torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
|
||||
logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)
|
||||
|
||||
outputs.logits = logits
|
||||
return outputs
|
|
@ -1,11 +1,11 @@
|
|||
[tool.poetry]
|
||||
name = "bloom-inference"
|
||||
name = "text-generation"
|
||||
version = "0.1.0"
|
||||
description = "BLOOM Inference Python gRPC Server"
|
||||
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
||||
|
||||
[tool.poetry.scripts]
|
||||
bloom-inference-server = 'bloom_inference.cli:app'
|
||||
text-generation-server = 'text_generation.cli:app'
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.9"
|
||||
|
@ -17,6 +17,9 @@ accelerate = "^0.12.0"
|
|||
joblib = "^1.2.0"
|
||||
bitsandbytes = "^0.35.1"
|
||||
|
||||
[tool.poetry.extras]
|
||||
bnb = ["bitsandbytes"]
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
grpcio-tools = "^1.49.1"
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from bloom_inference.model import Batch
|
||||
from typing import Dict, Optional
|
||||
|
||||
from text_generation.models.types import Batch
|
||||
|
||||
|
||||
class Cache:
|
||||
def __init__(self):
|
|
@ -3,7 +3,7 @@ import typer
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from bloom_inference import server, utils
|
||||
from text_generation import server, utils
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
@ -13,7 +13,7 @@ def serve(
|
|||
model_name: str,
|
||||
sharded: bool = False,
|
||||
quantize: bool = False,
|
||||
uds_path: Path = "/tmp/bloom-inference",
|
||||
uds_path: Path = "/tmp/text-generation",
|
||||
):
|
||||
if sharded:
|
||||
assert (
|
||||
|
@ -35,8 +35,9 @@ def serve(
|
|||
@app.command()
|
||||
def download_weights(
|
||||
model_name: str,
|
||||
extension: str = ".safetensors",
|
||||
):
|
||||
utils.download_weights(model_name)
|
||||
utils.download_weights(model_name, extension)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
|
@ -0,0 +1,22 @@
|
|||
from text_generation.models.model import Model
|
||||
from text_generation.models.bloom import BLOOMSharded
|
||||
|
||||
__all__ = ["Model", "BLOOMSharded"]
|
||||
|
||||
|
||||
def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
|
||||
|
||||
if model_name.startswith("bigscience/bloom"):
|
||||
if sharded:
|
||||
return BLOOMSharded(model_name, quantize)
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not supported for non-sharded BLOOM")
|
||||
return Model(model_name)
|
||||
else:
|
||||
if sharded:
|
||||
raise ValueError("sharded is only supported for BLOOM")
|
||||
if quantize:
|
||||
raise ValueError("Quantization is only supported for BLOOM models")
|
||||
|
||||
return Model(model_name)
|
|
@ -0,0 +1,231 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
from safetensors import safe_open
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
||||
from transformers.models.bloom.parallel_layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
from text_generation.models import Model
|
||||
from text_generation.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
download_weights,
|
||||
)
|
||||
|
||||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.nn import Int8Params
|
||||
except Exception as e:
|
||||
HAS_BITS_AND_BYTES = False
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
class BLOOMSharded(Model):
|
||||
def __init__(self, model_name: str, quantize: bool = False):
|
||||
super(Model, self).__init__()
|
||||
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
|
||||
self.master = self.rank == 0
|
||||
if torch.cuda.is_available():
|
||||
self.device = torch.device(f"cuda:{self.rank}")
|
||||
dtype = torch.float16
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_name, slow_but_exact=False, tp_parallel=True
|
||||
)
|
||||
config.pad_token_id = 3
|
||||
self.num_heads = config.n_head // self.process_group.size()
|
||||
|
||||
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
|
||||
# in PyTorch 1.12 and later.
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# Only download weights for small models
|
||||
if self.master and model_name == "bigscience/bloom-560m":
|
||||
download_weights(model_name, extension=".safetensors")
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_name, extension=".safetensors")
|
||||
|
||||
with init_empty_weights():
|
||||
model = AutoModelForCausalLM.from_config(config)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
quantize=quantize,
|
||||
device=self.device,
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
)
|
||||
self.model = model.eval().to(dtype)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: bool,
|
||||
device: torch.device,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
):
|
||||
parameters = dict(model.named_parameters())
|
||||
for file in filenames:
|
||||
with safe_open(
|
||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
||||
) as f:
|
||||
for name in f.keys():
|
||||
full_name = f"transformer.{name}"
|
||||
|
||||
module_name, param_name = full_name.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
current_tensor = parameters[full_name]
|
||||
|
||||
slice_ = f.get_slice(name)
|
||||
|
||||
if isinstance(module, TensorParallelColumnLinear):
|
||||
if param_name == "weight":
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
tensor = tensor.transpose(1, 0)
|
||||
else:
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif isinstance(module, TensorParallelRowLinear):
|
||||
if param_name == "weight":
|
||||
size = slice_.get_shape()[1]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[:, start:stop]
|
||||
tensor = tensor.transpose(1, 0)
|
||||
else:
|
||||
tensor = slice_[:]
|
||||
# XXX: Hack for Rowlinear to add the bias only once.
|
||||
if rank != 0:
|
||||
tensor = torch.zeros_like(tensor)
|
||||
elif isinstance(module, TensorParallelEmbedding):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
else:
|
||||
tensor = slice_[:]
|
||||
|
||||
if current_tensor.shape != tensor.shape:
|
||||
raise ValueError(
|
||||
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
|
||||
)
|
||||
|
||||
tensor = tensor.contiguous()
|
||||
|
||||
if quantize:
|
||||
if not HAS_BITS_AND_BYTES:
|
||||
raise ImportError(
|
||||
"bitsandbytes is not available on your machine either because it is not installed "
|
||||
"or you don't have a GPU.\n"
|
||||
"You can install it with `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
if (
|
||||
type(module)
|
||||
in [TensorParallelRowLinear, TensorParallelColumnLinear]
|
||||
and param_name == "weight"
|
||||
):
|
||||
tensor = Int8Params(
|
||||
tensor.transpose(1, 0),
|
||||
has_fp16_weights=False,
|
||||
requires_grad=False,
|
||||
).to(device)
|
||||
state = bnb.MatmulLtState()
|
||||
state.threshold = 6.0
|
||||
state.has_fp16_weights = False
|
||||
state.memory_efficient_backward = False
|
||||
state.use_pool = True
|
||||
state.CB = tensor.CB
|
||||
state.SCB = tensor.SCB
|
||||
tensor.CB = None
|
||||
tensor.SCB = None
|
||||
|
||||
def replace_linear(state, in_features, out_features):
|
||||
def linear(input, weight, bias):
|
||||
size_out = input.size()[:-1] + (out_features,)
|
||||
input = input.view(-1, in_features)
|
||||
out = torch.empty(
|
||||
size_out, device=input.device, dtype=input.dtype
|
||||
)
|
||||
out = bnb.matmul(
|
||||
input,
|
||||
weight,
|
||||
out=out.view(-1, out_features),
|
||||
state=state,
|
||||
threshold=state.threshold,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
if state.CB is not None:
|
||||
# we converted 8-bit row major to turing/ampere format
|
||||
# in the first inference pass
|
||||
# we no longer need the row-major weight
|
||||
del state.CB
|
||||
weight.data = state.CxB
|
||||
|
||||
return out.view(size_out)
|
||||
|
||||
return linear
|
||||
|
||||
module.linear = replace_linear(
|
||||
state, module.in_features, module.out_features
|
||||
)
|
||||
|
||||
else:
|
||||
tensor = tensor.to(device)
|
||||
|
||||
module._parameters[param_name] = tensor
|
||||
if name == "word_embeddings.weight":
|
||||
model.lm_head._parameters["weight"] = tensor
|
||||
|
||||
def forward(self, input_ids, attention_mask, past_key_values: Optional = None):
|
||||
outputs = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
# Logits are sharded, so we need to gather them
|
||||
logits_shard = outputs.logits[:, -1, :].contiguous()
|
||||
|
||||
batch_size, vocab_shard_size = logits_shard.shape
|
||||
vocab_size = self.world_size * vocab_shard_size
|
||||
logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
|
||||
torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
|
||||
logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)
|
||||
|
||||
outputs.logits = logits
|
||||
return outputs
|
|
@ -0,0 +1,166 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from typing import List, Tuple, Optional
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
from text_generation.models.types import Batch, GeneratedText
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self, model_name: str):
|
||||
if torch.cuda.is_available():
|
||||
self.device = torch.device("cuda")
|
||||
dtype = torch.float16
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
|
||||
self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, torch_dtype=dtype, device_map="auto"
|
||||
).eval()
|
||||
|
||||
self.num_heads = self.model.config.num_attention_heads
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, past_key_values: Optional = None
|
||||
) -> CausalLMOutputWithPast:
|
||||
# Model Forward
|
||||
return self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
def generate_token(
|
||||
self, batch: Batch
|
||||
) -> Tuple[List[GeneratedText], Optional[Batch]]:
|
||||
# For some reason, inference_mode does not work well with GLOO which we use on CPU
|
||||
context_manager = (
|
||||
torch.no_grad if self.device.type == "cpu" else torch.inference_mode
|
||||
)
|
||||
with context_manager():
|
||||
outputs = self.forward(**batch.input_ids)
|
||||
|
||||
# List of indices to cache
|
||||
next_batch_keep_indices = []
|
||||
next_batch_past_keep_indices = []
|
||||
|
||||
# New input_ids for next forward
|
||||
next_batch_input_ids = []
|
||||
next_batch_all_input_ids = []
|
||||
next_all_input_lengths = []
|
||||
|
||||
next_batch_size = 0
|
||||
next_batch_max_sequence_length = 0
|
||||
|
||||
# Finished requests
|
||||
generated_texts: List[GeneratedText] = []
|
||||
|
||||
# Zipped iterator
|
||||
iterator = zip(
|
||||
batch.requests,
|
||||
batch.all_input_lengths,
|
||||
outputs.logits,
|
||||
batch.next_token_choosers,
|
||||
batch.stopping_criterias,
|
||||
batch.all_input_ids,
|
||||
)
|
||||
|
||||
# For each member of the batch
|
||||
for i, (
|
||||
request,
|
||||
input_length,
|
||||
logits,
|
||||
next_token_chooser,
|
||||
stopping_criteria,
|
||||
all_tokens,
|
||||
) in enumerate(iterator):
|
||||
# Select next token
|
||||
next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
|
||||
|
||||
# Append next token to all tokens
|
||||
all_tokens = torch.cat([all_tokens, next_token])
|
||||
|
||||
# Evaluate stopping criteria
|
||||
if stopping_criteria(all_tokens):
|
||||
# Decode all tokens
|
||||
output = self.tokenizer.decode(
|
||||
all_tokens.squeeze(-1), skip_special_tokens=True
|
||||
)
|
||||
# Add to the list of finished generations with the original request
|
||||
generated_texts.append(GeneratedText(request, output))
|
||||
# add to the next batch
|
||||
else:
|
||||
next_batch_keep_indices.append(i)
|
||||
# past_key_values is of shape [batch_size * num_heads, ...]
|
||||
# so we need to take into account the `num_heads` stride here
|
||||
next_batch_past_keep_indices.extend(
|
||||
[j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
|
||||
)
|
||||
next_batch_input_ids.append(next_token)
|
||||
next_batch_all_input_ids.append(all_tokens)
|
||||
next_batch_size += 1
|
||||
new_input_length = input_length + 1
|
||||
next_all_input_lengths.append(new_input_length)
|
||||
next_batch_max_sequence_length = max(
|
||||
next_batch_max_sequence_length, new_input_length
|
||||
)
|
||||
|
||||
# We finished all generations in the batch; there is no next batch
|
||||
if not next_batch_keep_indices:
|
||||
return generated_texts, None
|
||||
|
||||
# If we finished at least one generation
|
||||
next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
|
||||
if generated_texts:
|
||||
# Apply indices to attention mask, past key values and other items that need to be cached
|
||||
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
|
||||
next_batch_keep_indices
|
||||
]
|
||||
next_batch_input_ids["past_key_values"] = [
|
||||
(
|
||||
keys[next_batch_past_keep_indices],
|
||||
values[next_batch_past_keep_indices],
|
||||
)
|
||||
for keys, values in outputs["past_key_values"]
|
||||
]
|
||||
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
|
||||
next_batch_next_token_choosers = [
|
||||
batch.next_token_choosers[i] for i in next_batch_keep_indices
|
||||
]
|
||||
next_batch_stopping_criterias = [
|
||||
batch.stopping_criterias[i] for i in next_batch_keep_indices
|
||||
]
|
||||
else:
|
||||
next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
|
||||
next_batch_input_ids["past_key_values"] = outputs["past_key_values"]
|
||||
next_batch_requests = batch.requests
|
||||
next_batch_next_token_choosers = batch.next_token_choosers
|
||||
next_batch_stopping_criterias = batch.stopping_criterias
|
||||
|
||||
# Update attention_mask with padding as we added a new token to input_ids
|
||||
next_batch_input_ids["attention_mask"] = torch.cat(
|
||||
[
|
||||
next_batch_input_ids["attention_mask"],
|
||||
torch.ones((next_batch_size, 1)).to(self.device),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
next_batch = Batch(
|
||||
batch_id=batch.batch_id,
|
||||
requests=next_batch_requests,
|
||||
all_input_lengths=next_all_input_lengths,
|
||||
input_ids=next_batch_input_ids,
|
||||
all_input_ids=next_batch_all_input_ids,
|
||||
next_token_choosers=next_batch_next_token_choosers,
|
||||
stopping_criterias=next_batch_stopping_criterias,
|
||||
size=next_batch_size,
|
||||
max_sequence_length=next_batch_max_sequence_length,
|
||||
)
|
||||
return generated_texts, next_batch
|
|
@ -0,0 +1,206 @@
|
|||
import torch
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from text_generation.pb import generate_pb2
|
||||
from text_generation.utils import NextTokenChooser, StoppingCriteria
|
||||
|
||||
|
||||
@dataclass
|
||||
class Batch:
|
||||
batch_id: int
|
||||
requests: List[generate_pb2.Request]
|
||||
all_input_lengths: List[int]
|
||||
input_ids: Dict[str, torch.Tensor]
|
||||
all_input_ids: List[torch.Tensor]
|
||||
next_token_choosers: List[NextTokenChooser]
|
||||
stopping_criterias: List[StoppingCriteria]
|
||||
size: int
|
||||
max_sequence_length: int
|
||||
|
||||
def to_pb(self):
|
||||
return generate_pb2.Batch(
|
||||
id=self.batch_id,
|
||||
requests=self.requests,
|
||||
size=self.size,
|
||||
max_sequence_length=self.max_sequence_length,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
|
||||
) -> "Batch":
|
||||
inputs = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
all_input_lengths = []
|
||||
|
||||
# Parse batch
|
||||
for r in pb.requests:
|
||||
inputs.append(r.inputs)
|
||||
all_input_lengths.append(r.input_length)
|
||||
next_token_choosers.append(
|
||||
NextTokenChooser(
|
||||
temperature=r.parameters.temperature,
|
||||
top_k=r.parameters.top_k,
|
||||
top_p=r.parameters.top_p,
|
||||
do_sample=r.parameters.do_sample,
|
||||
)
|
||||
)
|
||||
stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))
|
||||
|
||||
input_ids = tokenizer(
|
||||
inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
|
||||
).to(device)
|
||||
all_input_ids = input_ids["input_ids"].unsqueeze(-1)
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
requests=pb.requests,
|
||||
all_input_lengths=all_input_lengths,
|
||||
input_ids=input_ids,
|
||||
all_input_ids=all_input_ids,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
size=pb.size,
|
||||
max_sequence_length=pb.max_sequence_length,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def concatenate(cls, batches: List["Batch"]) -> "Batch":
|
||||
# Used for padding
|
||||
total_batch_size = sum(batch.size for batch in batches)
|
||||
max_sequence_length = max(batch.max_sequence_length for batch in batches)
|
||||
|
||||
# Batch attributes
|
||||
input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
|
||||
requests = []
|
||||
all_input_lengths = []
|
||||
all_input_ids = []
|
||||
next_token_choosers = []
|
||||
stopping_criterias = []
|
||||
|
||||
# Used for slicing correctly inside the tensors
|
||||
# Equivalent to a cumsum on batch sizes
|
||||
start_index = 0
|
||||
for i, batch in enumerate(batches):
|
||||
requests.extend(batch.requests)
|
||||
all_input_lengths.extend(batch.all_input_lengths)
|
||||
all_input_ids.extend(batch.all_input_ids)
|
||||
next_token_choosers.extend(batch.next_token_choosers)
|
||||
stopping_criterias.extend(batch.stopping_criterias)
|
||||
|
||||
# Slicing end index for this batch
|
||||
end_index = start_index + batch.size
|
||||
|
||||
# We only concatenate batches that did at least one step
|
||||
if batch.input_ids["input_ids"].shape[1] > 1:
|
||||
raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
|
||||
|
||||
# Initialize tensors
|
||||
if i == 0:
|
||||
input_ids["input_ids"] = torch.empty(
|
||||
(total_batch_size, 1),
|
||||
dtype=batch.input_ids["input_ids"].dtype,
|
||||
device=batch.input_ids["input_ids"].device,
|
||||
)
|
||||
input_ids["attention_mask"] = torch.zeros(
|
||||
(total_batch_size, max_sequence_length),
|
||||
dtype=batch.input_ids["attention_mask"].dtype,
|
||||
device=batch.input_ids["attention_mask"].device,
|
||||
)
|
||||
|
||||
# input_ids["input_ids"] is always of shape [batch_size, 1]
|
||||
# We do not need to pad it
|
||||
input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
|
||||
|
||||
# We need to slice the attention mask to remove padding from previous steps
|
||||
input_ids["attention_mask"][
|
||||
start_index:end_index, -batch.max_sequence_length :
|
||||
] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]
|
||||
|
||||
for j, past in enumerate(batch.input_ids["past_key_values"]):
|
||||
past_keys = past[0]
|
||||
past_values = past[1]
|
||||
|
||||
_, head_dim, padded_sequence_length = past_keys.shape
|
||||
|
||||
# Reshape the tensors to make slicing easier
|
||||
past_keys = past_keys.view(
|
||||
batch.size, -1, head_dim, padded_sequence_length
|
||||
)
|
||||
past_values = past_values.view(
|
||||
batch.size, -1, padded_sequence_length, head_dim
|
||||
)
|
||||
num_heads = past_keys.shape[1]
|
||||
|
||||
# Initialize tensors
|
||||
# This will run only once per layer
|
||||
if j == len(input_ids["past_key_values"]):
|
||||
padded_past_keys = torch.zeros(
|
||||
(
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
head_dim,
|
||||
max_sequence_length - 1,
|
||||
),
|
||||
dtype=past_keys.dtype,
|
||||
device=past_keys.device,
|
||||
)
|
||||
padded_past_values = torch.zeros(
|
||||
(
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
max_sequence_length - 1,
|
||||
head_dim,
|
||||
),
|
||||
dtype=past_values.dtype,
|
||||
device=past_values.device,
|
||||
)
|
||||
input_ids["past_key_values"].append(
|
||||
[padded_past_keys, padded_past_values]
|
||||
)
|
||||
|
||||
# We slice the past keys and values to remove the padding from previous batches
|
||||
input_ids["past_key_values"][j][0][
|
||||
start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
|
||||
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
|
||||
|
||||
input_ids["past_key_values"][j][1][
|
||||
start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
|
||||
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
|
||||
|
||||
# If we are on the last batch, we need to reshape the tensors
|
||||
if (i + 1) == len(batches):
|
||||
input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
|
||||
j
|
||||
][0].view(total_batch_size * num_heads, head_dim, -1)
|
||||
input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
|
||||
j
|
||||
][1].view(total_batch_size * num_heads, -1, head_dim)
|
||||
|
||||
start_index += batch.size
|
||||
|
||||
return cls(
|
||||
batch_id=batches[0].batch_id,
|
||||
requests=requests,
|
||||
all_input_lengths=all_input_lengths,
|
||||
input_ids=input_ids,
|
||||
all_input_ids=all_input_ids,
|
||||
next_token_choosers=next_token_choosers,
|
||||
stopping_criterias=stopping_criterias,
|
||||
size=total_batch_size,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratedText:
|
||||
request: generate_pb2.Request
|
||||
output: str
|
||||
|
||||
def to_pb(self) -> generate_pb2.GeneratedText:
|
||||
return generate_pb2.GeneratedText(request=self.request, output=self.output)
|
|
@ -5,15 +5,16 @@ from grpc import aio
|
|||
|
||||
from grpc_reflection.v1alpha import reflection
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
from typing import List
|
||||
|
||||
from bloom_inference.cache import Cache
|
||||
from bloom_inference.model import BLOOM, Batch, BLOOMSharded
|
||||
from bloom_inference.pb import generate_pb2_grpc, generate_pb2
|
||||
from text_generation.cache import Cache
|
||||
from text_generation.models import Model, get_model
|
||||
from text_generation.models.types import Batch
|
||||
from text_generation.pb import generate_pb2_grpc, generate_pb2
|
||||
|
||||
|
||||
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
def __init__(self, model: BLOOM, cache: Cache, server_urls: List[str]):
|
||||
def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
|
||||
self.cache = cache
|
||||
self.model = model
|
||||
self.server_urls = server_urls
|
||||
|
@ -78,21 +79,17 @@ def serve(
|
|||
):
|
||||
unix_socket_template = "unix://{}-{}"
|
||||
if sharded:
|
||||
model = BLOOMSharded(model_name, quantize)
|
||||
server_urls = [
|
||||
unix_socket_template.format(uds_path, rank)
|
||||
for rank in range(model.world_size)
|
||||
for rank in range(int(os.environ["WORLD_SIZE"]))
|
||||
]
|
||||
local_url = server_urls[model.rank]
|
||||
local_url = server_urls[int(os.environ["RANK"])]
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError(
|
||||
"bitsandbytes quantization is only available when running in `sharded` mode."
|
||||
)
|
||||
model = BLOOM(model_name)
|
||||
local_url = unix_socket_template.format(uds_path, 0)
|
||||
server_urls = [local_url]
|
||||
|
||||
model = get_model(model_name, sharded, quantize)
|
||||
|
||||
server = aio.server()
|
||||
generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
|
||||
TextGenerationService(model, Cache(), server_urls), server
|
|
@ -92,19 +92,17 @@ def initialize_torch_distributed():
|
|||
return torch.distributed.distributed_c10d._get_default_group(), rank, world_size
|
||||
|
||||
|
||||
def weight_hub_files(model_name):
|
||||
def weight_hub_files(model_name, extension=".safetensors"):
|
||||
"""Get the safetensors filenames on the hub"""
|
||||
api = HfApi()
|
||||
info = api.model_info(model_name)
|
||||
filenames = [
|
||||
s.rfilename for s in info.siblings if s.rfilename.endswith(".safetensors")
|
||||
]
|
||||
filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
|
||||
return filenames
|
||||
|
||||
|
||||
def weight_files(model_name):
|
||||
def weight_files(model_name, extension=".safetensors"):
|
||||
"""Get the local safetensors filenames"""
|
||||
filenames = weight_hub_files(model_name)
|
||||
filenames = weight_hub_files(model_name, extension)
|
||||
files = []
|
||||
for filename in filenames:
|
||||
cache_file = try_to_load_from_cache(model_name, filename=filename)
|
||||
|
@ -112,16 +110,16 @@ def weight_files(model_name):
|
|||
raise LocalEntryNotFoundError(
|
||||
f"File {filename} of model {model_name} not found in "
|
||||
f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
|
||||
f"Please run `bloom-inference-server download-weights {model_name}` first."
|
||||
f"Please run `text-generation-server download-weights {model_name}` first."
|
||||
)
|
||||
files.append(cache_file)
|
||||
|
||||
return files
|
||||
|
||||
|
||||
def download_weights(model_name):
|
||||
def download_weights(model_name, extension=".safetensors"):
|
||||
"""Download the safetensors files from the hub"""
|
||||
filenames = weight_hub_files(model_name)
|
||||
filenames = weight_hub_files(model_name, extension)
|
||||
|
||||
download_function = partial(
|
||||
hf_hub_download, repo_id=model_name, local_files_only=False
|
Loading…
Reference in New Issue