diff --git a/Cargo.lock b/Cargo.lock index 551f7aeb..33d75f0e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -28,9 +28,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.65" +version = "1.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98161a4e3e2184da77bb14f02184cdd111e83bbbcc9979dfee3c44b9a85f5602" +checksum = "216261ddc8289130e551ddcd5ce8a064710c0d064a4d2895c67151c92b5443f6" [[package]] name = "async-stream" @@ -83,9 +83,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "axum" -version = "0.5.16" +version = "0.5.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9e3356844c4d6a6d6467b8da2cffb4a2820be256f50a3a386c9d152bab31043" +checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43" dependencies = [ "async-trait", "axum-core", @@ -114,9 +114,9 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9f0c0a60006f2a293d82d571f635042a72edf927539b7685bd62d361963839b" +checksum = "37e5939e02c56fecd5c017c37df4238c0a839fa76b7f97acdd7efb804fd181cc" dependencies = [ "async-trait", "bytes", @@ -130,9 +130,9 @@ dependencies = [ [[package]] name = "base64" -version = "0.13.0" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] name = "bitflags" @@ -149,21 +149,6 @@ dependencies = [ "generic-array", ] -[[package]] -name = "bloom-inference-client" -version = "0.1.0" -dependencies = [ - "futures", - "prost", - "thiserror", - "tokio", - "tonic", - "tonic-build", - "tower", - "tracing", - "tracing-error", -] - [[package]] name = "bumpalo" version = "3.11.1" @@ -255,9 +240,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.0.17" +version = "4.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06badb543e734a2d6568e19a40af66ed5364360b9226184926f89d229b4b4267" +checksum = "335867764ed2de42325fafe6d18b8af74ba97ee0c590fa016f157535b42ab04b" dependencies = [ "atty", "bitflags", @@ -270,9 +255,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.0.13" +version = "4.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c42f169caba89a7d512b5418b09864543eeb4d497416c917d7137863bd2076ad" +checksum = "16a1b0f6422af32d5da0c58e2703320f379216ee70198241c84173a8c5ac28f3" dependencies = [ "heck 0.4.0", "proc-macro-error", @@ -532,14 +517,14 @@ dependencies = [ [[package]] name = "filetime" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94a7bbaa59354bc20dd75b67f23e2797b4490e9d6928203fb105c79e448c86c" +checksum = "4b9663d381d07ae25dc88dbdf27df458faa83a9b25336bcac83d5e452b5fc9d3" dependencies = [ "cfg-if", "libc", "redox_syscall", - "windows-sys 0.36.1", + "windows-sys 0.42.0", ] [[package]] @@ -600,9 +585,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f21eda599937fba36daeb58a22e8f5cee2d14c4a17b5b7739c7c8e5e3b8230c" +checksum = "38390104763dc37a5145a53c29c63c1290b5d316d6086ec32c293f6736051bb0" dependencies = [ "futures-channel", "futures-core", @@ -615,9 +600,9 @@ dependencies = [ 
[[package]] name = "futures-channel" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30bdd20c28fadd505d0fd6712cdfcb0d4b5648baf45faef7f852afb2399bb050" +checksum = "52ba265a92256105f45b719605a571ffe2d1f0fea3807304b522c1d778f79eed" dependencies = [ "futures-core", "futures-sink", @@ -625,15 +610,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e5aa3de05362c3fb88de6531e6296e85cde7739cccad4b9dfeeb7f6ebce56bf" +checksum = "04909a7a7e4633ae6c4a9ab280aeb86da1236243a77b694a49eacd659a4bd3ac" [[package]] name = "futures-executor" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ff63c23854bee61b6e9cd331d523909f238fc7636290b96826e9cfa5faa00ab" +checksum = "7acc85df6714c176ab5edf386123fafe217be88c0840ec11f199441134a074e2" dependencies = [ "futures-core", "futures-task", @@ -642,15 +627,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbf4d2a7a308fd4578637c0b17c7e1c7ba127b8f6ba00b29f717e9655d85eb68" +checksum = "00f5fb52a06bdcadeb54e8d3671f8888a39697dcb0b81b23b55174030427f4eb" [[package]] name = "futures-macro" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42cd15d1c7456c04dbdf7e88bcd69760d74f3a798d6444e16974b505b0e62f17" +checksum = "bdfb8ce053d86b91919aad980c220b1fb8401a9394410e1c289ed7e66b61835d" dependencies = [ "proc-macro2", "quote", @@ -659,21 +644,21 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b20ba5a92e727ba30e72834706623d94ac93a725410b6a6b6fbc1b07f7ba56" +checksum = "39c15cf1a4aa79df40f1bb462fb39676d0ad9e366c2a33b590d7c66f4f81fcf9" [[package]] name = "futures-task" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6508c467c73851293f390476d4491cf4d227dbabcd4170f3bb6044959b294f1" +checksum = "2ffb393ac5d9a6eaa9d3fdf37ae2776656b706e200c8e16b1bdb227f5198e6ea" [[package]] name = "futures-util" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44fb6cb1be61cc1d2e43b262516aafcf63b241cffdb1d3fa115f91d9c7b09c90" +checksum = "197676987abd2f9cadff84926f410af1c183608d36641465df73ae8211dc65d6" dependencies = [ "futures-channel", "futures-core", @@ -699,9 +684,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.7" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eb1a864a501629691edf6c15a593b7a51eebaa1e8468e9ddc623de7c9b58ec6" +checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" dependencies = [ "cfg-if", "libc", @@ -716,9 +701,9 @@ checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574" [[package]] name = "h2" -version = "0.3.14" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ca32592cf21ac7ccab1825cd87f6c9b3d9022c44d086172ed0966bec8af30be" +checksum = "5f9f29bc9dda355256b2916cf526ab02ce0aeaaaf2bad60d65ef3f12f11dd0f4" dependencies = [ "bytes", "fnv", @@ -967,9 +952,9 @@ checksum = 
"e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.135" +version = "0.2.137" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68783febc7782c6c5cb401fbda4de5a9898be1762314da0bb2c10ced61f18b0c" +checksum = "fc7fcc620a3bff7cdd7a365be3376c97191aeaccc2a603e600951e452615bf89" [[package]] name = "lock_api" @@ -992,9 +977,9 @@ dependencies = [ [[package]] name = "macro_rules_attribute" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "258c86475e1616d6f2d8f5227cfaabd3dae1f6d5388b9597df8a199d4497aba7" +checksum = "cf0c9b980bf4f3a37fd7b1c066941dd1b1d0152ce6ee6e8fe8c49b9f6810d862" dependencies = [ "macro_rules_attribute-proc_macro", "paste", @@ -1002,9 +987,9 @@ dependencies = [ [[package]] name = "macro_rules_attribute-proc_macro" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26a8d2502d5aa4d411ef494ba7470eb299f05725179ce3b5de77aa01a9ffdea" +checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d" [[package]] name = "matchit" @@ -1050,14 +1035,14 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57ee1c23c7c63b0c9250c339ffdc69255f110b298b901b9f6c82547b7b87caaf" +checksum = "e5d732bc30207a6423068df043e3d02e0735b155ad7ce1a6f76fe2baa5b158de" dependencies = [ "libc", "log", "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys 0.36.1", + "windows-sys 0.42.0", ] [[package]] @@ -1200,9 +1185,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.76" +version = "0.9.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5230151e44c0f05157effb743e8d517472843121cf9243e8b81393edb5acd9ce" +checksum = "b03b84c3b2d099b81f0953422b4d4ad58761589d0229b5506356afca05a3670a" dependencies = [ "autocfg", "cc", @@ -1213,9 +1198,9 @@ dependencies = [ [[package]] name = "os_str_bytes" -version = "6.3.0" +version = "6.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff" +checksum = "3baf96e39c5359d2eb0dd6ccb42c62b91d9678aa68160d261b9e0ccbf9e9dea9" [[package]] name = "overload" @@ -1302,9 +1287,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae" +checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160" [[package]] name = "ppv-lite86" @@ -1602,18 +1587,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.145" +version = "1.0.147" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "728eb6351430bccb993660dfffc5a72f91ccc1295abaa8ce19b27ebe4f75568b" +checksum = "d193d69bae983fc11a79df82342761dfbf28a99fc8d203dca4c3c1b590948965" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.145" +version = "1.0.147" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81fa1584d3d1bcacd84c277a0dfe21f5b0f6accf4a23d04d4c6d61f1af522b4c" +checksum = "4f1d362ca8fc9c3e3a7484440752472d68a6caa98f1ab81d99b5dfe517cec852" dependencies = [ "proc-macro2", 
"quote", @@ -1622,9 +1607,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.86" +version = "1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41feea4228a6f1cd09ec7a3593a682276702cd67b5273544757dae23c096f074" +checksum = "6ce777b7b150d76b9cf60d28b55f5847135a003f7d7350c6be7a773508ce7d45" dependencies = [ "itoa", "ryu", @@ -1739,9 +1724,9 @@ dependencies = [ [[package]] name = "syn" -version = "1.0.102" +version = "1.0.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fcd952facd492f9be3ef0d0b7032a6e442ee9b361d4acc2b1d0c4aaa5f613a1" +checksum = "a864042229133ada95abf3b54fdc62ef5ccabe9515b64717bcb9a1919e59445d" dependencies = [ "proc-macro2", "quote", @@ -1798,11 +1783,26 @@ dependencies = [ "winapi", ] +[[package]] +name = "text-generation-client" +version = "0.1.0" +dependencies = [ + "futures", + "prost", + "thiserror", + "tokio", + "tonic", + "tonic-build", + "tower", + "tracing", + "tracing-error", +] + [[package]] name = "text-generation-launcher" version = "0.1.0" dependencies = [ - "clap 4.0.17", + "clap 4.0.18", "ctrlc", "subprocess", "tracing", @@ -1814,12 +1814,12 @@ name = "text-generation-router" version = "0.1.0" dependencies = [ "axum", - "bloom-inference-client", - "clap 4.0.17", + "clap 4.0.18", "futures", "parking_lot", "serde", "serde_json", + "text-generation-client", "thiserror", "tokenizers", "tokio", diff --git a/Dockerfile b/Dockerfile index 08561b68..a2bf199a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -66,7 +66,7 @@ COPY proto proto COPY server server RUN cd server && \ make gen-server && \ - /opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir + /opt/miniconda/envs/text-generation/bin/pip install ".[bnb]" --no-cache-dir # Install router COPY --from=router-builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router diff --git a/Makefile b/Makefile index a52ed2b1..d427ff87 100644 --- a/Makefile +++ b/Makefile @@ -22,7 +22,7 @@ run-bloom-560m-quantize: text-generation-launcher --model-name bigscience/bloom-560m --num-shard 2 --quantize download-bloom: - bloom-inference-server download-weights bigscience/bloom + text-generation-server download-weights bigscience/bloom run-bloom: text-generation-launcher --model-name bigscience/bloom --num-shard 8 diff --git a/README.md b/README.md index e51d8c81..2d2d49d3 100644 --- a/README.md +++ b/README.md @@ -15,11 +15,13 @@ A Rust and gRPC server for large language models text generation inference. - [Safetensors](https://github.com/huggingface/safetensors) weight loading - 45ms per token generation for BLOOM with 8xA100 80GB -## Supported models +## Officially supported models - BLOOM - BLOOM-560m +Other models are supported on a best-effort basis using `AutoModelForCausalLM.from_pretrained(, torch_dtype=torch.float16, device_map="auto")`. 
+ ## Load Tests for BLOOM See `k6/load_test.js` diff --git a/aml/model.yaml b/aml/model.yaml index e4f1ded2..bd490f1a 100644 --- a/aml/model.yaml +++ b/aml/model.yaml @@ -1,5 +1,5 @@ $schema: https://azuremlschemas.azureedge.net/latest/model.schema.json -name: bloom +name: bloom-safetensors version: 1 -path: ./bloom +path: ./bloom-safetensors type: custom_model diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 3e82516a..f94dd589 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -256,7 +256,7 @@ fn shard_manager( // Process args let mut shard_argv = vec![ - "bloom-inference-server".to_string(), + "text-generation-server".to_string(), "serve".to_string(), model_name, "--uds-path".to_string(), @@ -311,7 +311,7 @@ fn shard_manager( Err(err) => { if let PopenError::IoError(ref err) = err { if err.kind() == io::ErrorKind::NotFound { - tracing::error!("bloom-inference-server not found in PATH"); + tracing::error!("text-generation-server not found in PATH"); tracing::error!("Please install it with `make install-server`") } } diff --git a/router/Cargo.toml b/router/Cargo.toml index 5820c138..da9518bf 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -14,7 +14,7 @@ path = "src/main.rs" [dependencies] axum = { version = "0.5.16", features = ["json", "serde_json"] } -bloom-inference-client = { path = "client" } +text-generation-client = { path = "client" } clap = { version = "4.0.15", features = ["derive", "env"] } futures = "0.3.24" parking_lot = "0.12.1" diff --git a/router/client/Cargo.toml b/router/client/Cargo.toml index 633f82a9..fdd32494 100644 --- a/router/client/Cargo.toml +++ b/router/client/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "bloom-inference-client" +name = "text-generation-client" version = "0.1.0" edition = "2021" diff --git a/router/src/batcher.rs b/router/src/batcher.rs index f131bf99..e381986a 100644 --- a/router/src/batcher.rs +++ b/router/src/batcher.rs @@ -3,9 +3,9 @@ use crate::{Db, Entry}; use crate::{ErrorResponse, GenerateRequest}; use axum::http::StatusCode; use axum::Json; -use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient}; use std::future::Future; use std::sync::Arc; +use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient}; use thiserror::Error; use tokio::sync::{oneshot, Notify}; use tokio::time::Instant; diff --git a/router/src/db.rs b/router/src/db.rs index 76a08ae0..af36614e 100644 --- a/router/src/db.rs +++ b/router/src/db.rs @@ -1,10 +1,10 @@ use crate::InferResponse; /// This code is massively inspired by Tokio mini-redis use crate::{GenerateParameters, GenerateRequest}; -use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request}; use parking_lot::Mutex; use std::collections::BTreeMap; use std::sync::Arc; +use text_generation_client::{Batch, ClientError, LogitsWarperParameters, Request}; use tokio::sync::oneshot::Sender; use tokio::time::Instant; diff --git a/router/src/main.rs b/router/src/main.rs index ea7ebd12..b24ec4c9 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -1,7 +1,7 @@ -/// Text Generation Inference webserver entrypoint -use bloom_inference_client::ShardedClient; use clap::Parser; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +/// Text Generation Inference webserver entrypoint +use text_generation_client::ShardedClient; use text_generation_router::server; use tokenizers::Tokenizer; @@ -19,7 +19,7 @@ struct Args { max_waiting_tokens: usize, #[clap(default_value = "3000", long, short, env)] port: u16, - #[clap(default_value = 
"/tmp/bloom-inference-0", long, env)] + #[clap(default_value = "/tmp/text-generation-0", long, env)] master_shard_uds_path: String, #[clap(default_value = "bigscience/bloom", long, env)] tokenizer_name: String, diff --git a/router/src/server.rs b/router/src/server.rs index 5698f4ec..d31c9494 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -6,9 +6,9 @@ use axum::http::{HeaderMap, StatusCode}; use axum::response::IntoResponse; use axum::routing::{get, post}; use axum::{Json, Router}; -use bloom_inference_client::ShardedClient; use std::net::SocketAddr; use std::sync::Arc; +use text_generation_client::ShardedClient; use tokenizers::Tokenizer; use tokio::signal; use tokio::sync::Semaphore; diff --git a/server/.gitignore b/server/.gitignore index 88bbea8f..5758ba92 100644 --- a/server/.gitignore +++ b/server/.gitignore @@ -1,7 +1,7 @@ # Byte-compiled / optimized / DLL files __pycache__/ -bloom_inference/__pycache__/ -bloom_inference/pb/__pycache__/ +text_generation/__pycache__/ +text_generation/pb/__pycache__/ *.py[cod] *$py.class diff --git a/server/Makefile b/server/Makefile index 9c7a2c43..4fa966e2 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,10 +1,10 @@ gen-server: # Compile protos pip install grpcio-tools==1.49.1 --no-cache-dir - mkdir bloom_inference/pb || true - python -m grpc_tools.protoc -I../proto --python_out=bloom_inference/pb --grpc_python_out=bloom_inference/pb ../proto/generate.proto - find bloom_inference/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; - touch bloom_inference/pb/__init__.py + mkdir text_generation/pb || true + python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto + find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; + touch text_generation/pb/__init__.py install-transformers: # Install specific version of transformers @@ -36,4 +36,4 @@ install: gen-server install-torch install-transformers install-safetensors pip install -e . 
--no-cache-dir run-dev: - python -m torch.distributed.run --nproc_per_node=2 bloom_inference/cli.py serve bigscience/bloom-560m --sharded \ No newline at end of file + python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded \ No newline at end of file diff --git a/server/bloom_inference/model.py b/server/bloom_inference/model.py deleted file mode 100644 index a0450052..00000000 --- a/server/bloom_inference/model.py +++ /dev/null @@ -1,582 +0,0 @@ -import torch -import torch.distributed - -from dataclasses import dataclass -from typing import List, Tuple, Optional, Dict - -from accelerate import init_empty_weights -from safetensors import safe_open -from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig -from transformers.models.bloom.parallel_layers import ( - TensorParallelColumnLinear, - TensorParallelEmbedding, - TensorParallelRowLinear, -) - -from bloom_inference.pb import generate_pb2 -from bloom_inference.utils import ( - StoppingCriteria, - NextTokenChooser, - initialize_torch_distributed, - weight_files, - download_weights, -) - -HAS_BITS_AND_BYTES = True -try: - import bitsandbytes as bnb - from bitsandbytes.nn import Int8Params -except Exception as e: - HAS_BITS_AND_BYTES = False - -torch.manual_seed(0) - - -@dataclass -class Batch: - batch_id: int - requests: List[generate_pb2.Request] - all_input_lengths: List[int] - input_ids: Dict[str, torch.Tensor] - all_input_ids: List[torch.Tensor] - next_token_choosers: List[NextTokenChooser] - stopping_criterias: List[StoppingCriteria] - size: int - max_sequence_length: int - - def to_pb(self): - return generate_pb2.Batch( - id=self.batch_id, - requests=self.requests, - size=self.size, - max_sequence_length=self.max_sequence_length, - ) - - @classmethod - def from_pb( - cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device - ) -> "Batch": - inputs = [] - next_token_choosers = [] - stopping_criterias = [] - all_input_lengths = [] - - # Parse batch - for r in pb.requests: - inputs.append(r.inputs) - all_input_lengths.append(r.input_length) - next_token_choosers.append( - NextTokenChooser( - temperature=r.parameters.temperature, - top_k=r.parameters.top_k, - top_p=r.parameters.top_p, - do_sample=r.parameters.do_sample, - ) - ) - stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens)) - - input_ids = tokenizer( - inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8 - ).to(device) - all_input_ids = input_ids["input_ids"].unsqueeze(-1) - - return cls( - batch_id=pb.id, - requests=pb.requests, - all_input_lengths=all_input_lengths, - input_ids=input_ids, - all_input_ids=all_input_ids, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - size=pb.size, - max_sequence_length=pb.max_sequence_length, - ) - - @classmethod - def concatenate(cls, batches: List["Batch"]) -> "Batch": - # Used for padding - total_batch_size = sum(batch.size for batch in batches) - max_sequence_length = max(batch.max_sequence_length for batch in batches) - - # Batch attributes - input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []} - requests = [] - all_input_lengths = [] - all_input_ids = [] - next_token_choosers = [] - stopping_criterias = [] - - # Used for slicing correctly inside the tensors - # Equivalent to a cumsum on batch sizes - start_index = 0 - for i, batch in enumerate(batches): - requests.extend(batch.requests) - all_input_lengths.extend(batch.all_input_lengths) - 
all_input_ids.extend(batch.all_input_ids) - next_token_choosers.extend(batch.next_token_choosers) - stopping_criterias.extend(batch.stopping_criterias) - - # Slicing end index for this batch - end_index = start_index + batch.size - - # We only concatenate batches that did at least one step - if batch.input_ids["input_ids"].shape[1] > 1: - raise ValueError("Batch input_ids should be of shape (batch_size, 1)") - - # Initialize tensors - if i == 0: - input_ids["input_ids"] = torch.empty( - (total_batch_size, 1), - dtype=batch.input_ids["input_ids"].dtype, - device=batch.input_ids["input_ids"].device, - ) - input_ids["attention_mask"] = torch.zeros( - (total_batch_size, max_sequence_length), - dtype=batch.input_ids["attention_mask"].dtype, - device=batch.input_ids["attention_mask"].device, - ) - - # input_ids["input_ids"] is always of shape [batch_size, 1] - # We do not need to pad it - input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"] - - # We need to slice the attention mask to remove padding from previous steps - input_ids["attention_mask"][ - start_index:end_index, -batch.max_sequence_length : - ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :] - - for j, past in enumerate(batch.input_ids["past_key_values"]): - past_keys = past[0] - past_values = past[1] - - _, head_dim, padded_sequence_length = past_keys.shape - - # Reshape the tensors to make slicing easier - past_keys = past_keys.view( - batch.size, -1, head_dim, padded_sequence_length - ) - past_values = past_values.view( - batch.size, -1, padded_sequence_length, head_dim - ) - num_heads = past_keys.shape[1] - - # Initialize tensors - # This will run only once per layer - if j == len(input_ids["past_key_values"]): - padded_past_keys = torch.zeros( - ( - total_batch_size, - num_heads, - head_dim, - max_sequence_length - 1, - ), - dtype=past_keys.dtype, - device=past_keys.device, - ) - padded_past_values = torch.zeros( - ( - total_batch_size, - num_heads, - max_sequence_length - 1, - head_dim, - ), - dtype=past_values.dtype, - device=past_values.device, - ) - input_ids["past_key_values"].append( - [padded_past_keys, padded_past_values] - ) - - # We slice the past keys and values to remove the padding from previous batches - input_ids["past_key_values"][j][0][ - start_index:end_index, :, :, -(batch.max_sequence_length - 1) : - ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :] - - input_ids["past_key_values"][j][1][ - start_index:end_index, :, -(batch.max_sequence_length - 1) :, : - ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :] - - # If we are on the last batch, we need to reshape the tensors - if (i + 1) == len(batches): - input_ids["past_key_values"][j][0] = input_ids["past_key_values"][ - j - ][0].view(total_batch_size * num_heads, head_dim, -1) - input_ids["past_key_values"][j][1] = input_ids["past_key_values"][ - j - ][1].view(total_batch_size * num_heads, -1, head_dim) - - start_index += batch.size - - return cls( - batch_id=batches[0].batch_id, - requests=requests, - all_input_lengths=all_input_lengths, - input_ids=input_ids, - all_input_ids=all_input_ids, - next_token_choosers=next_token_choosers, - stopping_criterias=stopping_criterias, - size=total_batch_size, - max_sequence_length=max_sequence_length, - ) - - -@dataclass -class GeneratedText: - request: generate_pb2.Request - output: str - - def to_pb(self) -> generate_pb2.GeneratedText: - return generate_pb2.GeneratedText(request=self.request, output=self.output) - - -class BLOOM: - def __init__(self, 
model_name: str): - if torch.cuda.is_available(): - self.device = torch.device("cuda") - dtype = torch.bfloat16 - else: - self.device = torch.device("cpu") - dtype = torch.float32 - - self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") - self.model = ( - AutoModelForCausalLM.from_pretrained(model_name) - .eval() - .to(self.device) - .to(dtype) - ) - self.num_heads = self.model.base_model.num_heads - - def forward(self, input_ids, attention_mask, past_key_values: Optional = None): - # Model Forward - return self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - - def generate_token( - self, batch: Batch - ) -> Tuple[List[GeneratedText], Optional[Batch]]: - with torch.inference_mode(): - outputs = self.forward(**batch.input_ids) - - # List of indices to cache - next_batch_keep_indices = [] - next_batch_past_keep_indices = [] - - # New input_ids for next forward - next_batch_input_ids = [] - next_batch_all_input_ids = [] - next_all_input_lengths = [] - - next_batch_size = 0 - next_batch_max_sequence_length = 0 - - # Finished requests - generated_texts: List[GeneratedText] = [] - - # Zipped iterator - iterator = zip( - batch.requests, - batch.all_input_lengths, - outputs.logits, - batch.next_token_choosers, - batch.stopping_criterias, - batch.all_input_ids, - ) - - # For each member of the batch - for i, ( - request, - input_length, - logits, - next_token_chooser, - stopping_criteria, - all_tokens, - ) in enumerate(iterator): - # Select next token - next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1]) - - # Append next token to all tokens - all_tokens = torch.cat([all_tokens, next_token]) - - # Evaluate stopping criteria - if stopping_criteria(all_tokens): - # Decode all tokens - output = self.tokenizer.decode( - all_tokens.squeeze(-1), skip_special_tokens=True - ) - # Add to the list of finished generations with the original request - generated_texts.append(GeneratedText(request, output)) - # add to the next batch - else: - next_batch_keep_indices.append(i) - # past_key_values is of shape [batch_size * num_heads, ...] 
- # so we need to take into account the `num_heads` stride here - next_batch_past_keep_indices.extend( - [j for j in range(i * self.num_heads, (i + 1) * self.num_heads)] - ) - next_batch_input_ids.append(next_token) - next_batch_all_input_ids.append(all_tokens) - next_batch_size += 1 - new_input_length = input_length + 1 - next_all_input_lengths.append(new_input_length) - next_batch_max_sequence_length = max( - next_batch_max_sequence_length, new_input_length - ) - - # We finished all generations in the batch; there is no next batch - if not next_batch_keep_indices: - return generated_texts, None - - # If we finished at least one generation - next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)} - if generated_texts: - # Apply indices to attention mask, past key values and other items that need to be cached - next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][ - next_batch_keep_indices - ] - next_batch_input_ids["past_key_values"] = [ - ( - keys[next_batch_past_keep_indices], - values[next_batch_past_keep_indices], - ) - for keys, values in outputs["past_key_values"] - ] - next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices] - next_batch_next_token_choosers = [ - batch.next_token_choosers[i] for i in next_batch_keep_indices - ] - next_batch_stopping_criterias = [ - batch.stopping_criterias[i] for i in next_batch_keep_indices - ] - else: - next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"] - next_batch_input_ids["past_key_values"] = outputs["past_key_values"] - next_batch_requests = batch.requests - next_batch_next_token_choosers = batch.next_token_choosers - next_batch_stopping_criterias = batch.stopping_criterias - - # Update attention_mask with padding as we added a new token to input_ids - next_batch_input_ids["attention_mask"] = torch.cat( - [ - next_batch_input_ids["attention_mask"], - torch.ones((next_batch_size, 1)).to(self.device), - ], - dim=1, - ) - - next_batch = Batch( - batch_id=batch.batch_id, - requests=next_batch_requests, - all_input_lengths=next_all_input_lengths, - input_ids=next_batch_input_ids, - all_input_ids=next_batch_all_input_ids, - next_token_choosers=next_batch_next_token_choosers, - stopping_criterias=next_batch_stopping_criterias, - size=next_batch_size, - max_sequence_length=next_batch_max_sequence_length, - ) - return generated_texts, next_batch - - -class BLOOMSharded(BLOOM): - def __init__(self, model_name: str, quantize: bool = False): - super(BLOOM, self).__init__() - self.process_group, self.rank, self.world_size = initialize_torch_distributed() - self.master = self.rank == 0 - if torch.cuda.is_available(): - self.device = torch.device(f"cuda:{self.rank}") - dtype = torch.float16 - else: - self.device = torch.device("cpu") - dtype = torch.float32 - - self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") - - config = AutoConfig.from_pretrained( - model_name, slow_but_exact=False, tp_parallel=True - ) - config.pad_token_id = 3 - self.num_heads = config.n_head // self.process_group.size() - - # The flag below controls whether to allow TF32 on matmul. This flag defaults to False - # in PyTorch 1.12 and later. - torch.backends.cuda.matmul.allow_tf32 = True - - # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. 
- torch.backends.cudnn.allow_tf32 = True - - # Only download weights for small models - if self.master and model_name == "bigscience/bloom-560m": - download_weights(model_name) - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_name) - - with init_empty_weights(): - model = AutoModelForCausalLM.from_config(config) - - torch.distributed.barrier(group=self.process_group) - self.load_weights( - model, - filenames, - quantize=quantize, - device=self.device, - rank=self.rank, - world_size=self.world_size, - ) - self.model = model.eval().to(dtype) - torch.distributed.barrier(group=self.process_group) - - @staticmethod - def load_weights( - model, - filenames: List[str], - quantize: bool, - device: torch.device, - rank: int, - world_size: int, - ): - parameters = dict(model.named_parameters()) - for file in filenames: - with safe_open( - file, framework="pt", device=str(device) if not quantize else "cpu" - ) as f: - for name in f.keys(): - full_name = f"transformer.{name}" - - module_name, param_name = full_name.rsplit(".", 1) - module = model.get_submodule(module_name) - current_tensor = parameters[full_name] - - slice_ = f.get_slice(name) - - if isinstance(module, TensorParallelColumnLinear): - if param_name == "weight": - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - tensor = tensor.transpose(1, 0) - else: - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - elif isinstance(module, TensorParallelRowLinear): - if param_name == "weight": - size = slice_.get_shape()[1] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[:, start:stop] - tensor = tensor.transpose(1, 0) - else: - tensor = slice_[:] - # XXX: Hack for Rowlinear to add the bias only once. 
- if rank != 0: - tensor = torch.zeros_like(tensor) - elif isinstance(module, TensorParallelEmbedding): - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - tensor = slice_[:] - - if current_tensor.shape != tensor.shape: - raise ValueError( - f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}" - ) - - tensor = tensor.contiguous() - - if quantize: - if not HAS_BITS_AND_BYTES: - raise ImportError( - "bitsandbytes is not available on your machine" - ) - - if ( - type(module) - in [TensorParallelRowLinear, TensorParallelColumnLinear] - and param_name == "weight" - ): - tensor = Int8Params( - tensor.transpose(1, 0), - has_fp16_weights=False, - requires_grad=False, - ).to(device) - state = bnb.MatmulLtState() - state.threshold = 6.0 - state.has_fp16_weights = False - state.memory_efficient_backward = False - state.use_pool = True - state.CB = tensor.CB - state.SCB = tensor.SCB - tensor.CB = None - tensor.SCB = None - - def replace_linear(state, in_features, out_features): - def linear(input, weight, bias): - size_out = input.size()[:-1] + (out_features,) - input = input.view(-1, in_features) - out = torch.empty( - size_out, device=input.device, dtype=input.dtype - ) - out = bnb.matmul( - input, - weight, - out=out.view(-1, out_features), - state=state, - threshold=state.threshold, - bias=bias, - ) - - if state.CB is not None: - # we converted 8-bit row major to turing/ampere format - # in the first inference pass - # we no longer need the row-major weight - del state.CB - weight.data = state.CxB - - return out.view(size_out) - - return linear - - module.linear = replace_linear( - state, module.in_features, module.out_features - ) - - else: - tensor = tensor.to(device) - - module._parameters[param_name] = tensor - if name == "word_embeddings.weight": - model.lm_head._parameters["weight"] = tensor - - def forward(self, input_ids, attention_mask, past_key_values: Optional = None): - outputs = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - - # Logits are sharded, so we need to gather them - logits_shard = outputs.logits[:, -1, :].contiguous() - - batch_size, vocab_shard_size = logits_shard.shape - vocab_size = self.world_size * vocab_shard_size - logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)] - torch.distributed.all_gather(logits, logits_shard, group=self.process_group) - logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size) - - outputs.logits = logits - return outputs diff --git a/server/pyproject.toml b/server/pyproject.toml index 4e5e98b3..50f99398 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -1,11 +1,11 @@ [tool.poetry] -name = "bloom-inference" +name = "text-generation" version = "0.1.0" description = "BLOOM Inference Python gRPC Server" authors = ["Olivier Dehaene "] [tool.poetry.scripts] -bloom-inference-server = 'bloom_inference.cli:app' +text-generation-server = 'text_generation.cli:app' [tool.poetry.dependencies] python = "^3.9" @@ -17,6 +17,9 @@ accelerate = "^0.12.0" joblib = "^1.2.0" bitsandbytes = "^0.35.1" +[tool.poetry.extras] +bnb = ["bitsandbytes"] + [tool.poetry.group.dev.dependencies] grpcio-tools = "^1.49.1" diff --git a/server/bloom_inference/__init__.py b/server/text_generation/__init__.py similarity index 100% rename from server/bloom_inference/__init__.py rename to server/text_generation/__init__.py 
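Between the removed `bloom_inference/model.py` above and the new `text_generation/models/bloom.py` below, the sharded forward pass is unchanged: each rank computes logits for its slice of the vocabulary, and the shards are all-gathered back into a full-vocabulary tensor. A standalone sketch of that gather step, assuming `torch.distributed` is already initialized (e.g. via `torchrun`) and the shapes are illustrative:

```python
# Sketch of the logit all-gather performed by BLOOMSharded.forward.
# Assumes the default process group is already initialized.
import torch
import torch.distributed as dist


def gather_logits(logits_shard: torch.Tensor, world_size: int, group=None) -> torch.Tensor:
    """logits_shard: [batch_size, vocab_size // world_size] for the local rank."""
    batch_size, vocab_shard_size = logits_shard.shape
    shards = [torch.empty_like(logits_shard) for _ in range(world_size)]
    dist.all_gather(shards, logits_shard, group=group)
    # Reassemble the vocabulary dimension and keep a sequence dimension of 1,
    # matching the last-position logits returned by the sharded forward.
    return torch.cat(shards, dim=1).view(batch_size, 1, world_size * vocab_shard_size)
```

The new `[bnb]` extra in `pyproject.toml` pairs with the Dockerfile change to `pip install ".[bnb]"`, so the bitsandbytes int8 path is opted into explicitly at install time.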
diff --git a/server/bloom_inference/cache.py b/server/text_generation/cache.py similarity index 91% rename from server/bloom_inference/cache.py rename to server/text_generation/cache.py index 6812b306..65ec3e7c 100644 --- a/server/bloom_inference/cache.py +++ b/server/text_generation/cache.py @@ -1,6 +1,7 @@ -from bloom_inference.model import Batch from typing import Dict, Optional +from text_generation.models.types import Batch + class Cache: def __init__(self): diff --git a/server/bloom_inference/cli.py b/server/text_generation/cli.py similarity index 83% rename from server/bloom_inference/cli.py rename to server/text_generation/cli.py index 6360aec0..c41b9751 100644 --- a/server/bloom_inference/cli.py +++ b/server/text_generation/cli.py @@ -3,7 +3,7 @@ import typer from pathlib import Path -from bloom_inference import server, utils +from text_generation import server, utils app = typer.Typer() @@ -13,7 +13,7 @@ def serve( model_name: str, sharded: bool = False, quantize: bool = False, - uds_path: Path = "/tmp/bloom-inference", + uds_path: Path = "/tmp/text-generation", ): if sharded: assert ( @@ -35,8 +35,9 @@ def serve( @app.command() def download_weights( model_name: str, + extension: str = ".safetensors", ): - utils.download_weights(model_name) + utils.download_weights(model_name, extension) if __name__ == "__main__": diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py new file mode 100644 index 00000000..d4f3cf8b --- /dev/null +++ b/server/text_generation/models/__init__.py @@ -0,0 +1,22 @@ +from text_generation.models.model import Model +from text_generation.models.bloom import BLOOMSharded + +__all__ = ["Model", "BLOOMSharded"] + + +def get_model(model_name: str, sharded: bool, quantize: bool) -> Model: + + if model_name.startswith("bigscience/bloom"): + if sharded: + return BLOOMSharded(model_name, quantize) + else: + if quantize: + raise ValueError("quantization is not supported for non-sharded BLOOM") + return Model(model_name) + else: + if sharded: + raise ValueError("sharded is only supported for BLOOM") + if quantize: + raise ValueError("Quantization is only supported for BLOOM models") + + return Model(model_name) diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py new file mode 100644 index 00000000..172ca38d --- /dev/null +++ b/server/text_generation/models/bloom.py @@ -0,0 +1,231 @@ +import torch +import torch.distributed + +from typing import List, Optional + +from accelerate import init_empty_weights +from safetensors import safe_open +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +from transformers.models.bloom.parallel_layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, +) + +from text_generation.models import Model +from text_generation.utils import ( + initialize_torch_distributed, + weight_files, + download_weights, +) + +HAS_BITS_AND_BYTES = True +try: + import bitsandbytes as bnb + from bitsandbytes.nn import Int8Params +except Exception as e: + HAS_BITS_AND_BYTES = False + +torch.manual_seed(0) + + +class BLOOMSharded(Model): + def __init__(self, model_name: str, quantize: bool = False): + super(Model, self).__init__() + self.process_group, self.rank, self.world_size = initialize_torch_distributed() + self.master = self.rank == 0 + if torch.cuda.is_available(): + self.device = torch.device(f"cuda:{self.rank}") + dtype = torch.float16 + else: + self.device = torch.device("cpu") + dtype = torch.float32 
+ + self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + + config = AutoConfig.from_pretrained( + model_name, slow_but_exact=False, tp_parallel=True + ) + config.pad_token_id = 3 + self.num_heads = config.n_head // self.process_group.size() + + # The flag below controls whether to allow TF32 on matmul. This flag defaults to False + # in PyTorch 1.12 and later. + torch.backends.cuda.matmul.allow_tf32 = True + + # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. + torch.backends.cudnn.allow_tf32 = True + + # Only download weights for small models + if self.master and model_name == "bigscience/bloom-560m": + download_weights(model_name, extension=".safetensors") + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_name, extension=".safetensors") + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + + torch.distributed.barrier(group=self.process_group) + self.load_weights( + model, + filenames, + quantize=quantize, + device=self.device, + rank=self.rank, + world_size=self.world_size, + ) + self.model = model.eval().to(dtype) + torch.distributed.barrier(group=self.process_group) + + @staticmethod + def load_weights( + model, + filenames: List[str], + quantize: bool, + device: torch.device, + rank: int, + world_size: int, + ): + parameters = dict(model.named_parameters()) + for file in filenames: + with safe_open( + file, framework="pt", device=str(device) if not quantize else "cpu" + ) as f: + for name in f.keys(): + full_name = f"transformer.{name}" + + module_name, param_name = full_name.rsplit(".", 1) + module = model.get_submodule(module_name) + current_tensor = parameters[full_name] + + slice_ = f.get_slice(name) + + if isinstance(module, TensorParallelColumnLinear): + if param_name == "weight": + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + tensor = tensor.transpose(1, 0) + else: + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif isinstance(module, TensorParallelRowLinear): + if param_name == "weight": + size = slice_.get_shape()[1] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[:, start:stop] + tensor = tensor.transpose(1, 0) + else: + tensor = slice_[:] + # XXX: Hack for Rowlinear to add the bias only once. + if rank != 0: + tensor = torch.zeros_like(tensor) + elif isinstance(module, TensorParallelEmbedding): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + else: + tensor = slice_[:] + + if current_tensor.shape != tensor.shape: + raise ValueError( + f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}" + ) + + tensor = tensor.contiguous() + + if quantize: + if not HAS_BITS_AND_BYTES: + raise ImportError( + "bitsandbytes is not available on your machine either because it is not installed " + "or you don't have a GPU.\n" + "You can install it with `pip install bitsandbytes`." 
+ ) + + if ( + type(module) + in [TensorParallelRowLinear, TensorParallelColumnLinear] + and param_name == "weight" + ): + tensor = Int8Params( + tensor.transpose(1, 0), + has_fp16_weights=False, + requires_grad=False, + ).to(device) + state = bnb.MatmulLtState() + state.threshold = 6.0 + state.has_fp16_weights = False + state.memory_efficient_backward = False + state.use_pool = True + state.CB = tensor.CB + state.SCB = tensor.SCB + tensor.CB = None + tensor.SCB = None + + def replace_linear(state, in_features, out_features): + def linear(input, weight, bias): + size_out = input.size()[:-1] + (out_features,) + input = input.view(-1, in_features) + out = torch.empty( + size_out, device=input.device, dtype=input.dtype + ) + out = bnb.matmul( + input, + weight, + out=out.view(-1, out_features), + state=state, + threshold=state.threshold, + bias=bias, + ) + + if state.CB is not None: + # we converted 8-bit row major to turing/ampere format + # in the first inference pass + # we no longer need the row-major weight + del state.CB + weight.data = state.CxB + + return out.view(size_out) + + return linear + + module.linear = replace_linear( + state, module.in_features, module.out_features + ) + + else: + tensor = tensor.to(device) + + module._parameters[param_name] = tensor + if name == "word_embeddings.weight": + model.lm_head._parameters["weight"] = tensor + + def forward(self, input_ids, attention_mask, past_key_values: Optional = None): + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True, + ) + + # Logits are sharded, so we need to gather them + logits_shard = outputs.logits[:, -1, :].contiguous() + + batch_size, vocab_shard_size = logits_shard.shape + vocab_size = self.world_size * vocab_shard_size + logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)] + torch.distributed.all_gather(logits, logits_shard, group=self.process_group) + logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size) + + outputs.logits = logits + return outputs diff --git a/server/text_generation/models/model.py b/server/text_generation/models/model.py new file mode 100644 index 00000000..db1367b9 --- /dev/null +++ b/server/text_generation/models/model.py @@ -0,0 +1,166 @@ +import torch +import torch.distributed + +from typing import List, Tuple, Optional +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +from transformers.modeling_outputs import CausalLMOutputWithPast + +from text_generation.models.types import Batch, GeneratedText + + +class Model: + def __init__(self, model_name: str): + if torch.cuda.is_available(): + self.device = torch.device("cuda") + dtype = torch.float16 + else: + self.device = torch.device("cpu") + dtype = torch.float32 + + self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") + self.tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + self.model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=dtype, device_map="auto" + ).eval() + + self.num_heads = self.model.config.num_attention_heads + + def forward( + self, input_ids, attention_mask, past_key_values: Optional = None + ) -> CausalLMOutputWithPast: + # Model Forward + return self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True, + ) + + def generate_token( + self, batch: Batch + ) -> Tuple[List[GeneratedText], Optional[Batch]]: + # For some reason, inference_mode does not work well with 
GLOO which we use on CPU + context_manager = ( + torch.no_grad if self.device.type == "cpu" else torch.inference_mode + ) + with context_manager(): + outputs = self.forward(**batch.input_ids) + + # List of indices to cache + next_batch_keep_indices = [] + next_batch_past_keep_indices = [] + + # New input_ids for next forward + next_batch_input_ids = [] + next_batch_all_input_ids = [] + next_all_input_lengths = [] + + next_batch_size = 0 + next_batch_max_sequence_length = 0 + + # Finished requests + generated_texts: List[GeneratedText] = [] + + # Zipped iterator + iterator = zip( + batch.requests, + batch.all_input_lengths, + outputs.logits, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + ) + + # For each member of the batch + for i, ( + request, + input_length, + logits, + next_token_chooser, + stopping_criteria, + all_tokens, + ) in enumerate(iterator): + # Select next token + next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1]) + + # Append next token to all tokens + all_tokens = torch.cat([all_tokens, next_token]) + + # Evaluate stopping criteria + if stopping_criteria(all_tokens): + # Decode all tokens + output = self.tokenizer.decode( + all_tokens.squeeze(-1), skip_special_tokens=True + ) + # Add to the list of finished generations with the original request + generated_texts.append(GeneratedText(request, output)) + # add to the next batch + else: + next_batch_keep_indices.append(i) + # past_key_values is of shape [batch_size * num_heads, ...] + # so we need to take into account the `num_heads` stride here + next_batch_past_keep_indices.extend( + [j for j in range(i * self.num_heads, (i + 1) * self.num_heads)] + ) + next_batch_input_ids.append(next_token) + next_batch_all_input_ids.append(all_tokens) + next_batch_size += 1 + new_input_length = input_length + 1 + next_all_input_lengths.append(new_input_length) + next_batch_max_sequence_length = max( + next_batch_max_sequence_length, new_input_length + ) + + # We finished all generations in the batch; there is no next batch + if not next_batch_keep_indices: + return generated_texts, None + + # If we finished at least one generation + next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)} + if generated_texts: + # Apply indices to attention mask, past key values and other items that need to be cached + next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][ + next_batch_keep_indices + ] + next_batch_input_ids["past_key_values"] = [ + ( + keys[next_batch_past_keep_indices], + values[next_batch_past_keep_indices], + ) + for keys, values in outputs["past_key_values"] + ] + next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices] + next_batch_next_token_choosers = [ + batch.next_token_choosers[i] for i in next_batch_keep_indices + ] + next_batch_stopping_criterias = [ + batch.stopping_criterias[i] for i in next_batch_keep_indices + ] + else: + next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"] + next_batch_input_ids["past_key_values"] = outputs["past_key_values"] + next_batch_requests = batch.requests + next_batch_next_token_choosers = batch.next_token_choosers + next_batch_stopping_criterias = batch.stopping_criterias + + # Update attention_mask with padding as we added a new token to input_ids + next_batch_input_ids["attention_mask"] = torch.cat( + [ + next_batch_input_ids["attention_mask"], + torch.ones((next_batch_size, 1)).to(self.device), + ], + dim=1, + ) + + next_batch = Batch( + 
batch_id=batch.batch_id, + requests=next_batch_requests, + all_input_lengths=next_all_input_lengths, + input_ids=next_batch_input_ids, + all_input_ids=next_batch_all_input_ids, + next_token_choosers=next_batch_next_token_choosers, + stopping_criterias=next_batch_stopping_criterias, + size=next_batch_size, + max_sequence_length=next_batch_max_sequence_length, + ) + return generated_texts, next_batch diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py new file mode 100644 index 00000000..39c33ab7 --- /dev/null +++ b/server/text_generation/models/types.py @@ -0,0 +1,206 @@ +import torch + +from dataclasses import dataclass +from typing import List, Dict + +from transformers import AutoTokenizer + +from text_generation.pb import generate_pb2 +from text_generation.utils import NextTokenChooser, StoppingCriteria + + +@dataclass +class Batch: + batch_id: int + requests: List[generate_pb2.Request] + all_input_lengths: List[int] + input_ids: Dict[str, torch.Tensor] + all_input_ids: List[torch.Tensor] + next_token_choosers: List[NextTokenChooser] + stopping_criterias: List[StoppingCriteria] + size: int + max_sequence_length: int + + def to_pb(self): + return generate_pb2.Batch( + id=self.batch_id, + requests=self.requests, + size=self.size, + max_sequence_length=self.max_sequence_length, + ) + + @classmethod + def from_pb( + cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device + ) -> "Batch": + inputs = [] + next_token_choosers = [] + stopping_criterias = [] + all_input_lengths = [] + + # Parse batch + for r in pb.requests: + inputs.append(r.inputs) + all_input_lengths.append(r.input_length) + next_token_choosers.append( + NextTokenChooser( + temperature=r.parameters.temperature, + top_k=r.parameters.top_k, + top_p=r.parameters.top_p, + do_sample=r.parameters.do_sample, + ) + ) + stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens)) + + input_ids = tokenizer( + inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8 + ).to(device) + all_input_ids = input_ids["input_ids"].unsqueeze(-1) + + return cls( + batch_id=pb.id, + requests=pb.requests, + all_input_lengths=all_input_lengths, + input_ids=input_ids, + all_input_ids=all_input_ids, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + size=pb.size, + max_sequence_length=pb.max_sequence_length, + ) + + @classmethod + def concatenate(cls, batches: List["Batch"]) -> "Batch": + # Used for padding + total_batch_size = sum(batch.size for batch in batches) + max_sequence_length = max(batch.max_sequence_length for batch in batches) + + # Batch attributes + input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []} + requests = [] + all_input_lengths = [] + all_input_ids = [] + next_token_choosers = [] + stopping_criterias = [] + + # Used for slicing correctly inside the tensors + # Equivalent to a cumsum on batch sizes + start_index = 0 + for i, batch in enumerate(batches): + requests.extend(batch.requests) + all_input_lengths.extend(batch.all_input_lengths) + all_input_ids.extend(batch.all_input_ids) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) + + # Slicing end index for this batch + end_index = start_index + batch.size + + # We only concatenate batches that did at least one step + if batch.input_ids["input_ids"].shape[1] > 1: + raise ValueError("Batch input_ids should be of shape (batch_size, 1)") + + # Initialize tensors + if i == 0: + 
input_ids["input_ids"] = torch.empty( + (total_batch_size, 1), + dtype=batch.input_ids["input_ids"].dtype, + device=batch.input_ids["input_ids"].device, + ) + input_ids["attention_mask"] = torch.zeros( + (total_batch_size, max_sequence_length), + dtype=batch.input_ids["attention_mask"].dtype, + device=batch.input_ids["attention_mask"].device, + ) + + # input_ids["input_ids"] is always of shape [batch_size, 1] + # We do not need to pad it + input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"] + + # We need to slice the attention mask to remove padding from previous steps + input_ids["attention_mask"][ + start_index:end_index, -batch.max_sequence_length : + ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :] + + for j, past in enumerate(batch.input_ids["past_key_values"]): + past_keys = past[0] + past_values = past[1] + + _, head_dim, padded_sequence_length = past_keys.shape + + # Reshape the tensors to make slicing easier + past_keys = past_keys.view( + batch.size, -1, head_dim, padded_sequence_length + ) + past_values = past_values.view( + batch.size, -1, padded_sequence_length, head_dim + ) + num_heads = past_keys.shape[1] + + # Initialize tensors + # This will run only once per layer + if j == len(input_ids["past_key_values"]): + padded_past_keys = torch.zeros( + ( + total_batch_size, + num_heads, + head_dim, + max_sequence_length - 1, + ), + dtype=past_keys.dtype, + device=past_keys.device, + ) + padded_past_values = torch.zeros( + ( + total_batch_size, + num_heads, + max_sequence_length - 1, + head_dim, + ), + dtype=past_values.dtype, + device=past_values.device, + ) + input_ids["past_key_values"].append( + [padded_past_keys, padded_past_values] + ) + + # We slice the past keys and values to remove the padding from previous batches + input_ids["past_key_values"][j][0][ + start_index:end_index, :, :, -(batch.max_sequence_length - 1) : + ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :] + + input_ids["past_key_values"][j][1][ + start_index:end_index, :, -(batch.max_sequence_length - 1) :, : + ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :] + + # If we are on the last batch, we need to reshape the tensors + if (i + 1) == len(batches): + input_ids["past_key_values"][j][0] = input_ids["past_key_values"][ + j + ][0].view(total_batch_size * num_heads, head_dim, -1) + input_ids["past_key_values"][j][1] = input_ids["past_key_values"][ + j + ][1].view(total_batch_size * num_heads, -1, head_dim) + + start_index += batch.size + + return cls( + batch_id=batches[0].batch_id, + requests=requests, + all_input_lengths=all_input_lengths, + input_ids=input_ids, + all_input_ids=all_input_ids, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + size=total_batch_size, + max_sequence_length=max_sequence_length, + ) + + +@dataclass +class GeneratedText: + request: generate_pb2.Request + output: str + + def to_pb(self) -> generate_pb2.GeneratedText: + return generate_pb2.GeneratedText(request=self.request, output=self.output) diff --git a/server/bloom_inference/pb/.gitignore b/server/text_generation/pb/.gitignore similarity index 100% rename from server/bloom_inference/pb/.gitignore rename to server/text_generation/pb/.gitignore diff --git a/server/bloom_inference/server.py b/server/text_generation/server.py similarity index 83% rename from server/bloom_inference/server.py rename to server/text_generation/server.py index ad40a52a..fffeb0ba 100644 --- a/server/bloom_inference/server.py +++ 
b/server/text_generation/server.py @@ -5,15 +5,16 @@ from grpc import aio from grpc_reflection.v1alpha import reflection from pathlib import Path -from typing import Optional, List +from typing import List -from bloom_inference.cache import Cache -from bloom_inference.model import BLOOM, Batch, BLOOMSharded -from bloom_inference.pb import generate_pb2_grpc, generate_pb2 +from text_generation.cache import Cache +from text_generation.models import Model, get_model +from text_generation.models.types import Batch +from text_generation.pb import generate_pb2_grpc, generate_pb2 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): - def __init__(self, model: BLOOM, cache: Cache, server_urls: List[str]): + def __init__(self, model: Model, cache: Cache, server_urls: List[str]): self.cache = cache self.model = model self.server_urls = server_urls @@ -78,21 +79,17 @@ def serve( ): unix_socket_template = "unix://{}-{}" if sharded: - model = BLOOMSharded(model_name, quantize) server_urls = [ unix_socket_template.format(uds_path, rank) - for rank in range(model.world_size) + for rank in range(int(os.environ["WORLD_SIZE"])) ] - local_url = server_urls[model.rank] + local_url = server_urls[int(os.environ["RANK"])] else: - if quantize: - raise ValueError( - "bitsandbytes quantization is only available when running in `sharded` mode." - ) - model = BLOOM(model_name) local_url = unix_socket_template.format(uds_path, 0) server_urls = [local_url] + model = get_model(model_name, sharded, quantize) + server = aio.server() generate_pb2_grpc.add_TextGenerationServiceServicer_to_server( TextGenerationService(model, Cache(), server_urls), server diff --git a/server/bloom_inference/utils.py b/server/text_generation/utils.py similarity index 89% rename from server/bloom_inference/utils.py rename to server/text_generation/utils.py index cca5403f..926fc69c 100644 --- a/server/bloom_inference/utils.py +++ b/server/text_generation/utils.py @@ -92,19 +92,17 @@ def initialize_torch_distributed(): return torch.distributed.distributed_c10d._get_default_group(), rank, world_size -def weight_hub_files(model_name): +def weight_hub_files(model_name, extension=".safetensors"): """Get the safetensors filenames on the hub""" api = HfApi() info = api.model_info(model_name) - filenames = [ - s.rfilename for s in info.siblings if s.rfilename.endswith(".safetensors") - ] + filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)] return filenames -def weight_files(model_name): +def weight_files(model_name, extension=".safetensors"): """Get the local safetensors filenames""" - filenames = weight_hub_files(model_name) + filenames = weight_hub_files(model_name, extension) files = [] for filename in filenames: cache_file = try_to_load_from_cache(model_name, filename=filename) @@ -112,16 +110,16 @@ def weight_files(model_name): raise LocalEntryNotFoundError( f"File {filename} of model {model_name} not found in " f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. " - f"Please run `bloom-inference-server download-weights {model_name}` first." + f"Please run `text-generation-server download-weights {model_name}` first." ) files.append(cache_file) return files -def download_weights(model_name): +def download_weights(model_name, extension=".safetensors"): """Download the safetensors files from the hub""" - filenames = weight_hub_files(model_name) + filenames = weight_hub_files(model_name, extension) download_function = partial( hf_hub_download, repo_id=model_name, local_files_only=False
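The `utils.py` changes above generalize the weight-file helpers from a hard-coded `.safetensors` suffix to a caller-supplied `extension`. A minimal sketch of the hub-side listing these helpers rely on, assuming `huggingface_hub` is installed (the model name is illustrative):

```python
# Sketch of listing a model's weight files on the Hub, filtered by extension.
from typing import List

from huggingface_hub import HfApi


def weight_hub_files(model_name: str, extension: str = ".safetensors") -> List[str]:
    info = HfApi().model_info(model_name)
    return [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]


# Illustrative usage: list the safetensors shards of a small model.
print(weight_hub_files("bigscience/bloom-560m"))
```

Since the CLI is built with typer, the `extension` parameter added to `download_weights` should surface as an optional flag, e.g. `text-generation-server download-weights bigscience/bloom --extension ".safetensors"` (flag name inferred from typer defaults).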