feat: Add arguments to CLI
This commit is contained in:
parent
5e5d8766a2
commit
92c1ecd008
|
@ -9,7 +9,7 @@ WORKDIR /usr/src/router
|
|||
|
||||
RUN cargo install --path .
|
||||
|
||||
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
|
||||
FROM nvidia/cuda:11.6.1-devel-ubuntu18.04
|
||||
|
||||
ENV LANG=C.UTF-8 \
|
||||
LC_ALL=C.UTF-8 \
|
||||
|
|
|
@ -43,7 +43,6 @@ python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-dire
|
|||
|
||||
## TODO:
|
||||
|
||||
- [ ] Add batching args to router CLI
|
||||
- [ ] Add docstrings + comments everywhere as the codebase is fairly complicated
|
||||
- [ ] Add tests
|
||||
- [ ] Add shutdown logic in router and server
|
||||
|
|
|
@ -253,6 +253,43 @@ dependencies = [
|
|||
"vec_map",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6bf8832993da70a4c6d13c581f4463c2bdda27b9bf1c5498dc4365543abe6d6f"
|
||||
dependencies = [
|
||||
"atty",
|
||||
"bitflags",
|
||||
"clap_derive",
|
||||
"clap_lex",
|
||||
"once_cell",
|
||||
"strsim 0.10.0",
|
||||
"termcolor",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_derive"
|
||||
version = "4.0.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c42f169caba89a7d512b5418b09864543eeb4d497416c917d7137863bd2076ad"
|
||||
dependencies = [
|
||||
"heck 0.4.0",
|
||||
"proc-macro-error",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_lex"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0d4198f73e42b4936b35b5bb248d81d2b595ecb170da0bac7655c54eedfa8da8"
|
||||
dependencies = [
|
||||
"os_str_bytes",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "console"
|
||||
version = "0.15.2"
|
||||
|
@ -701,6 +738,12 @@ dependencies = [
|
|||
"unicode-segmentation",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9"
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.1.19"
|
||||
|
@ -1136,6 +1179,12 @@ dependencies = [
|
|||
"vcpkg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "os_str_bytes"
|
||||
version = "6.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff"
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.12.1"
|
||||
|
@ -1225,6 +1274,30 @@ version = "0.2.16"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
|
||||
dependencies = [
|
||||
"proc-macro-error-attr",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error-attr"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.46"
|
||||
|
@ -1251,7 +1324,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "62941722fb675d463659e49c4f3fe1fe792ff24fe5bbaa9c08cd3b98a1c354f5"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"heck",
|
||||
"heck 0.3.3",
|
||||
"itertools 0.10.5",
|
||||
"lazy_static",
|
||||
"log",
|
||||
|
@ -1601,6 +1674,12 @@ version = "0.9.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c"
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.101"
|
||||
|
@ -1643,6 +1722,15 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bab24d30b911b2376f3a13cc2cd443142f0c81dda04c118693e35b3835757755"
|
||||
dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "terminal_size"
|
||||
version = "0.1.17"
|
||||
|
@ -1659,6 +1747,7 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"axum",
|
||||
"bloom-inference-client",
|
||||
"clap 4.0.15",
|
||||
"futures",
|
||||
"parking_lot",
|
||||
"serde",
|
||||
|
@ -1742,7 +1831,7 @@ checksum = "3d7b08ede6742d7a59d58c71da8a6fa21bedc433dca2e855e439274d08df1170"
|
|||
dependencies = [
|
||||
"aho-corasick",
|
||||
"cached-path",
|
||||
"clap",
|
||||
"clap 2.34.0",
|
||||
"derive_builder",
|
||||
"dirs",
|
||||
"esaxx-rs",
|
||||
|
@ -2251,6 +2340,15 @@ version = "0.4.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||
|
||||
[[package]]
|
||||
name = "winapi-util"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-x86_64-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
name = "text-generation-router"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
authors = ["Olivier Dehaene"]
|
||||
description = "Text Generation Webserver"
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
@ -13,6 +15,7 @@ path = "src/main.rs"
|
|||
[dependencies]
|
||||
axum = { version = "0.5.16", features = ["json", "serde_json"] }
|
||||
bloom-inference-client = { path = "client" }
|
||||
clap = { version = "4.0.15", features = ["derive", "env"] }
|
||||
futures = "0.3.24"
|
||||
parking_lot = "0.12.1"
|
||||
serde = "1.0.145"
|
||||
|
|
|
@ -27,7 +27,7 @@ impl From<InferError> for (StatusCode, String) {
|
|||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct Batcher {
|
||||
pub struct Batcher {
|
||||
db: Db,
|
||||
shared: Arc<Shared>,
|
||||
}
|
||||
|
@ -37,13 +37,13 @@ struct Shared {
|
|||
}
|
||||
|
||||
impl Batcher {
|
||||
pub(crate) fn new(client: ShardedClient) -> Self {
|
||||
pub(crate) fn new(client: ShardedClient, max_batch_size: usize) -> Self {
|
||||
let db = Db::new();
|
||||
let shared = Arc::new(Shared {
|
||||
batching_task: Notify::new(),
|
||||
});
|
||||
|
||||
tokio::spawn(batching_task(client, db.clone(), shared.clone()));
|
||||
tokio::spawn(batching_task(max_batch_size, client, db.clone(), shared.clone()));
|
||||
|
||||
Self { db, shared }
|
||||
}
|
||||
|
@ -70,40 +70,46 @@ impl Batcher {
|
|||
}
|
||||
}
|
||||
|
||||
async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
|
||||
async fn batching_task(max_batch_size: usize,
|
||||
client: ShardedClient,
|
||||
db: Db,
|
||||
shared: Arc<Shared>) {
|
||||
let limit_min_batch_size = (max_batch_size / 2) as u32;
|
||||
|
||||
loop {
|
||||
shared.batching_task.notified().await;
|
||||
|
||||
if let Some(batch) = db.next_batch(32) {
|
||||
if let Some(batch) = db.next_batch(max_batch_size) {
|
||||
let request_ids = batch.requests.iter().map(|req| req.id).collect();
|
||||
let mut cached_batch = match batch.size {
|
||||
size if size > 16 => {
|
||||
size if size > limit_min_batch_size => {
|
||||
wrap_future(client.generate_until_finished(batch), request_ids, &db).await
|
||||
}
|
||||
_ => wrap_future(client.generate(batch), request_ids, &db).await,
|
||||
};
|
||||
|
||||
while let Some(batch) = cached_batch {
|
||||
let batch_size = batch.size;
|
||||
let mut current_batch_size = batch.size;
|
||||
let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
|
||||
let mut batches = vec![batch];
|
||||
|
||||
if batch_size <= 16 {
|
||||
if let Some(new_batch) = db.next_batch_minimum_size(16, 48) {
|
||||
if current_batch_size <= limit_min_batch_size {
|
||||
if let Some(new_batch) = db.next_batch_minimum_size(limit_min_batch_size as usize, max_batch_size) {
|
||||
let new_batch_request_ids =
|
||||
new_batch.requests.iter().map(|req| req.id).collect();
|
||||
let new_cached_batch =
|
||||
wrap_future(client.generate(new_batch), new_batch_request_ids, &db)
|
||||
.await;
|
||||
if let Some(new_cached_batch) = new_cached_batch {
|
||||
current_batch_size += new_cached_batch.size;
|
||||
request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
|
||||
batches.push(new_cached_batch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cached_batch = match batch_size {
|
||||
size if size > 16 => {
|
||||
cached_batch = match current_batch_size {
|
||||
size if size > limit_min_batch_size => {
|
||||
wrap_future(
|
||||
client.generate_until_finished_with_cache(batches),
|
||||
request_ids,
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
mod batcher;
|
||||
mod db;
|
||||
pub mod server;
|
||||
mod validation;
|
||||
pub mod server;
|
||||
|
||||
use batcher::Batcher;
|
||||
use db::{Db, Entry};
|
||||
use batcher::Batcher;
|
||||
use validation::Validation;
|
||||
|
|
|
@ -1,10 +1,36 @@
|
|||
use bloom_inference_client::ShardedClient;
|
||||
use std::net::SocketAddr;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use text_generation_router::server;
|
||||
use tokenizers::Tokenizer;
|
||||
use clap::Parser;
|
||||
|
||||
/// App Configuration
///
/// Command-line arguments for the router binary, parsed with clap's derive
/// API. Every option declares a `default_value` and carries the `env`
/// attribute, so it can also be supplied through an environment variable.
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Batch size limit forwarded to `server::run` (default: 32).
#[clap(default_value = "32", long, short, env)]
|
||||
max_batch_size: usize,
|
||||
/// TCP port used to build the socket address handed to `server::run`;
/// the IP part of that address is fixed to 0.0.0.0 in `main` (default: 3000).
#[clap(default_value = "3000", long, short, env)]
|
||||
port: u16,
|
||||
/// Path passed to `ShardedClient::connect_uds`
/// (presumably a Unix domain socket path; confirm against the client crate).
#[clap(default_value = "/tmp/bloom-inference-0", long, env)]
|
||||
shard_uds_path: String,
|
||||
/// Name passed to `Tokenizer::from_pretrained` to load the tokenizer.
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||
tokenizer_name: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<(), std::io::Error> {
|
||||
let tokenizer = Tokenizer::from_pretrained("bigscience/bloom", None).unwrap();
|
||||
// Get args
|
||||
let args = Args::parse();
|
||||
// Pattern match configuration
|
||||
let Args {
|
||||
max_batch_size,
|
||||
port,
|
||||
shard_uds_path,
|
||||
tokenizer_name,
|
||||
} = args;
|
||||
|
||||
|
||||
let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();
|
||||
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
|
@ -13,7 +39,7 @@ fn main() -> Result<(), std::io::Error> {
|
|||
.block_on(async {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let sharded_client = ShardedClient::connect_uds("/tmp/bloom-inference-0".to_string())
|
||||
let sharded_client = ShardedClient::connect_uds(shard_uds_path)
|
||||
.await
|
||||
.expect("Could not connect to server");
|
||||
sharded_client
|
||||
|
@ -22,9 +48,9 @@ fn main() -> Result<(), std::io::Error> {
|
|||
.expect("Unable to clear cache");
|
||||
tracing::info!("Connected");
|
||||
|
||||
let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
|
||||
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
|
||||
|
||||
server::run(sharded_client, tokenizer, addr).await;
|
||||
server::run(max_batch_size, sharded_client, tokenizer, addr).await;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
|
|
@ -64,7 +64,7 @@ pub(crate) struct GenerateRequest {
|
|||
#[instrument(skip(state), fields(time, time_per_token))]
|
||||
async fn liveness(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
|
||||
state
|
||||
.infer
|
||||
.batcher
|
||||
.infer(
|
||||
1,
|
||||
GenerateRequest {
|
||||
|
@ -97,7 +97,7 @@ async fn generate(
|
|||
})
|
||||
.await?;
|
||||
|
||||
let generated_text = state.infer.infer(input_length, validated_request).await?;
|
||||
let generated_text = state.batcher.infer(input_length, validated_request).await?;
|
||||
|
||||
tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
|
||||
tracing::Span::current().record(
|
||||
|
@ -114,18 +114,14 @@ async fn generate(
|
|||
#[derive(Clone)]
|
||||
struct ServerState {
|
||||
validation: Validation,
|
||||
infer: Batcher,
|
||||
batcher: Batcher,
|
||||
}
|
||||
|
||||
pub async fn run(client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) {
|
||||
client.clear_cache().await.expect("Unable to clear cache");
|
||||
tracing::info!("Connected");
|
||||
|
||||
let infer = Batcher::new(client);
|
||||
|
||||
pub async fn run(max_batch_size: usize, client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) {
|
||||
let batcher = Batcher::new(client, max_batch_size);
|
||||
let validation = Validation::new(tokenizer);
|
||||
|
||||
let shared_state = ServerState { validation, infer };
|
||||
let shared_state = ServerState { validation, batcher };
|
||||
|
||||
let app = Router::new()
|
||||
.route("/generate", post(generate))
|
||||
|
|
|
@ -14,7 +14,7 @@ pub enum ValidationError {
|
|||
TopK,
|
||||
#[error("Max New Tokens must be < 512")]
|
||||
MaxNewTokens,
|
||||
#[error("Inputs must have less than 512 tokens. Given: {0}")]
|
||||
#[error("Inputs must have less than 1000 tokens. Given: {0}")]
|
||||
InputLength(usize),
|
||||
}
|
||||
|
||||
|
@ -30,7 +30,7 @@ type ValidationRequest = (
|
|||
);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct Validation {
|
||||
pub struct Validation {
|
||||
sender: mpsc::Sender<ValidationRequest>,
|
||||
}
|
||||
|
||||
|
@ -81,7 +81,7 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<Vali
|
|||
let inputs = tokenizer.encode(request.inputs.clone(), false).unwrap();
|
||||
let input_length = inputs.len();
|
||||
|
||||
if input_length > 512 {
|
||||
if input_length > 1000 {
|
||||
response_tx
|
||||
.send(Err(ValidationError::InputLength(input_length)))
|
||||
.unwrap_or(());
|
||||
|
|
Loading…
Reference in New Issue