feat: Add arguments to CLI

Olivier Dehaene 2022-10-17 18:27:33 +02:00
parent 5e5d8766a2
commit 92c1ecd008
9 changed files with 163 additions and 35 deletions

Dockerfile

@@ -9,7 +9,7 @@ WORKDIR /usr/src/router
 RUN cargo install --path .
-FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
+FROM nvidia/cuda:11.6.1-devel-ubuntu18.04
 ENV LANG=C.UTF-8 \
     LC_ALL=C.UTF-8 \

README.md

@@ -43,7 +43,6 @@ python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-dire
 ## TODO:
-- [ ] Add batching args to router CLI
 - [ ] Add docstrings + comments everywhere as the codebase is fairly complicated
 - [ ] Add tests
 - [ ] Add shutdown logic in router and server

router/Cargo.lock generated

@@ -253,6 +253,43 @@ dependencies = [
  "vec_map",
 ]
 
+[[package]]
+name = "clap"
+version = "4.0.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6bf8832993da70a4c6d13c581f4463c2bdda27b9bf1c5498dc4365543abe6d6f"
+dependencies = [
+ "atty",
+ "bitflags",
+ "clap_derive",
+ "clap_lex",
+ "once_cell",
+ "strsim 0.10.0",
+ "termcolor",
+]
+
+[[package]]
+name = "clap_derive"
+version = "4.0.13"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c42f169caba89a7d512b5418b09864543eeb4d497416c917d7137863bd2076ad"
+dependencies = [
+ "heck 0.4.0",
+ "proc-macro-error",
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "clap_lex"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0d4198f73e42b4936b35b5bb248d81d2b595ecb170da0bac7655c54eedfa8da8"
+dependencies = [
+ "os_str_bytes",
+]
+
 [[package]]
 name = "console"
 version = "0.15.2"
@@ -701,6 +738,12 @@ dependencies = [
  "unicode-segmentation",
 ]
 
+[[package]]
+name = "heck"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9"
+
 [[package]]
 name = "hermit-abi"
 version = "0.1.19"
@@ -1136,6 +1179,12 @@ dependencies = [
  "vcpkg",
 ]
 
+[[package]]
+name = "os_str_bytes"
+version = "6.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff"
+
 [[package]]
 name = "parking_lot"
 version = "0.12.1"
@@ -1225,6 +1274,30 @@ version = "0.2.16"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
 
+[[package]]
+name = "proc-macro-error"
+version = "1.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
+dependencies = [
+ "proc-macro-error-attr",
+ "proc-macro2",
+ "quote",
+ "syn",
+ "version_check",
+]
+
+[[package]]
+name = "proc-macro-error-attr"
+version = "1.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "version_check",
+]
+
 [[package]]
 name = "proc-macro2"
 version = "1.0.46"
@@ -1251,7 +1324,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "62941722fb675d463659e49c4f3fe1fe792ff24fe5bbaa9c08cd3b98a1c354f5"
 dependencies = [
  "bytes",
- "heck",
+ "heck 0.3.3",
  "itertools 0.10.5",
  "lazy_static",
  "log",
@@ -1601,6 +1674,12 @@ version = "0.9.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c"
 
+[[package]]
+name = "strsim"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
+
 [[package]]
 name = "syn"
 version = "1.0.101"
@@ -1643,6 +1722,15 @@ dependencies = [
  "winapi",
 ]
 
+[[package]]
+name = "termcolor"
+version = "1.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bab24d30b911b2376f3a13cc2cd443142f0c81dda04c118693e35b3835757755"
+dependencies = [
+ "winapi-util",
+]
+
 [[package]]
 name = "terminal_size"
 version = "0.1.17"
@@ -1659,6 +1747,7 @@ version = "0.1.0"
 dependencies = [
  "axum",
  "bloom-inference-client",
+ "clap 4.0.15",
  "futures",
  "parking_lot",
  "serde",
@@ -1742,7 +1831,7 @@ checksum = "3d7b08ede6742d7a59d58c71da8a6fa21bedc433dca2e855e439274d08df1170"
 dependencies = [
  "aho-corasick",
  "cached-path",
- "clap",
+ "clap 2.34.0",
  "derive_builder",
  "dirs",
  "esaxx-rs",
@@ -2251,6 +2340,15 @@ version = "0.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
 
+[[package]]
+name = "winapi-util"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178"
+dependencies = [
+ "winapi",
+]
+
 [[package]]
 name = "winapi-x86_64-pc-windows-gnu"
 version = "0.4.0"

router/Cargo.toml

@@ -2,6 +2,8 @@
 name = "text-generation-router"
 version = "0.1.0"
 edition = "2021"
+authors = ["Olivier Dehaene"]
+description = "Text Generation Webserver"
 
 [lib]
 path = "src/lib.rs"
@@ -13,6 +15,7 @@ path = "src/main.rs"
 [dependencies]
 axum = { version = "0.5.16", features = ["json", "serde_json"] }
 bloom-inference-client = { path = "client" }
+clap = { version = "4.0.15", features = ["derive", "env"] }
 futures = "0.3.24"
 parking_lot = "0.12.1"
 serde = "1.0.145"
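
The two clap features matter here: `derive` enables the `#[derive(Parser)]` attribute used in `main.rs` below, and `env` lets each flag fall back to an environment variable when it is absent from the command line. A minimal sketch of the pattern (hypothetical one-field struct, not part of this commit):

use clap::Parser;

/// Hypothetical example showing the `derive` + `env` combination.
#[derive(Parser, Debug)]
struct Example {
    // With `env`, a missing --max-batch-size flag falls back to the
    // MAX_BATCH_SIZE environment variable before the default applies.
    #[clap(default_value = "32", long, env)]
    max_batch_size: usize,
}

fn main() {
    let example = Example::parse();
    println!("max_batch_size = {}", example.max_batch_size);
}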

router/src/batcher.rs

@@ -27,7 +27,7 @@ impl From<InferError> for (StatusCode, String) {
 }
 
 #[derive(Clone)]
-pub(crate) struct Batcher {
+pub struct Batcher {
     db: Db,
     shared: Arc<Shared>,
 }
@@ -37,13 +37,13 @@ struct Shared {
 }
 
 impl Batcher {
-    pub(crate) fn new(client: ShardedClient) -> Self {
+    pub(crate) fn new(client: ShardedClient, max_batch_size: usize) -> Self {
         let db = Db::new();
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
 
-        tokio::spawn(batching_task(client, db.clone(), shared.clone()));
+        tokio::spawn(batching_task(max_batch_size, client, db.clone(), shared.clone()));
 
         Self { db, shared }
     }
@@ -70,40 +70,46 @@ impl Batcher {
     }
 }
 
-async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
+async fn batching_task(max_batch_size: usize,
+                       client: ShardedClient,
+                       db: Db,
+                       shared: Arc<Shared>) {
+    let limit_min_batch_size = (max_batch_size / 2) as u32;
+
     loop {
         shared.batching_task.notified().await;
 
-        if let Some(batch) = db.next_batch(32) {
+        if let Some(batch) = db.next_batch(max_batch_size) {
             let request_ids = batch.requests.iter().map(|req| req.id).collect();
             let mut cached_batch = match batch.size {
-                size if size > 16 => {
+                size if size > limit_min_batch_size => {
                     wrap_future(client.generate_until_finished(batch), request_ids, &db).await
                 }
                 _ => wrap_future(client.generate(batch), request_ids, &db).await,
             };
 
             while let Some(batch) = cached_batch {
-                let batch_size = batch.size;
+                let mut current_batch_size = batch.size;
                 let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
                 let mut batches = vec![batch];
 
-                if batch_size <= 16 {
-                    if let Some(new_batch) = db.next_batch_minimum_size(16, 48) {
+                if current_batch_size <= limit_min_batch_size {
+                    if let Some(new_batch) = db.next_batch_minimum_size(limit_min_batch_size as usize, max_batch_size) {
                         let new_batch_request_ids =
                             new_batch.requests.iter().map(|req| req.id).collect();
                         let new_cached_batch =
                             wrap_future(client.generate(new_batch), new_batch_request_ids, &db)
                                 .await;
                         if let Some(new_cached_batch) = new_cached_batch {
+                            current_batch_size += new_cached_batch.size;
                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
                             batches.push(new_cached_batch);
                         }
                     }
                 }
 
-                cached_batch = match batch_size {
-                    size if size > 16 => {
+                cached_batch = match current_batch_size {
+                    size if size > limit_min_batch_size => {
                         wrap_future(
                             client.generate_until_finished_with_cache(batches),
                             request_ids,
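
The change replaces the hard-coded `32`/`16`/`48` batch sizes with values derived from the new `max_batch_size` argument: batches above half the maximum run with `generate_until_finished`, while smaller ones decode one step at a time so the task can top them up from the waiting queue. A standalone sketch of that decision rule (hypothetical helper, simplified from the task above):

// Simplified sketch of the routing rule in `batching_task`; the function
// name is made up for illustration.
fn run_until_finished(current_batch_size: u32, max_batch_size: usize) -> bool {
    // Batches above half the maximum are "large enough" to decode to
    // completion; smaller ones stay open so new requests can be merged in.
    let limit_min_batch_size = (max_batch_size / 2) as u32;
    current_batch_size > limit_min_batch_size
}

fn main() {
    // With the default --max-batch-size 32, the threshold is 16.
    assert!(!run_until_finished(16, 32)); // small: keep open for merging
    assert!(run_until_finished(17, 32)); // large: run until finished
    println!("threshold checks passed");
}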

router/src/lib.rs

@@ -1,8 +1,8 @@
 mod batcher;
 mod db;
-pub mod server;
 mod validation;
+pub mod server;
 
-use batcher::Batcher;
 use db::{Db, Entry};
+use batcher::Batcher;
 use validation::Validation;

router/src/main.rs

@@ -1,10 +1,36 @@
 use bloom_inference_client::ShardedClient;
-use std::net::SocketAddr;
+use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use text_generation_router::server;
 use tokenizers::Tokenizer;
+use clap::Parser;
+
+/// App Configuration
+#[derive(Parser, Debug)]
+#[clap(author, version, about, long_about = None)]
+struct Args {
+    #[clap(default_value = "32", long, short, env)]
+    max_batch_size: usize,
+    #[clap(default_value = "3000", long, short, env)]
+    port: u16,
+    #[clap(default_value = "/tmp/bloom-inference-0", long, env)]
+    shard_uds_path: String,
+    #[clap(default_value = "bigscience/bloom", long, env)]
+    tokenizer_name: String,
+}
 
 fn main() -> Result<(), std::io::Error> {
-    let tokenizer = Tokenizer::from_pretrained("bigscience/bloom", None).unwrap();
+    // Get args
+    let args = Args::parse();
+    // Pattern match configuration
+    let Args {
+        max_batch_size,
+        port,
+        shard_uds_path,
+        tokenizer_name,
+    } = args;
+
+    let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();
 
     tokio::runtime::Builder::new_multi_thread()
         .enable_all()
@@ -13,7 +39,7 @@ fn main() -> Result<(), std::io::Error> {
         .block_on(async {
             tracing_subscriber::fmt::init();
 
-            let sharded_client = ShardedClient::connect_uds("/tmp/bloom-inference-0".to_string())
+            let sharded_client = ShardedClient::connect_uds(shard_uds_path)
                 .await
                 .expect("Could not connect to server");
             sharded_client
@@ -22,9 +48,9 @@ fn main() -> Result<(), std::io::Error> {
                 .expect("Unable to clear cache");
             tracing::info!("Connected");
 
-            let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
+            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
 
-            server::run(sharded_client, tokenizer, addr).await;
+            server::run(max_batch_size, sharded_client, tokenizer, addr).await;
             Ok(())
         })
 }
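
With this `Args` struct, every setting can come from a flag, an environment variable, or the default. A quick sketch of how the parsed struct behaves (`parse_from` is clap's standard way to feed arguments programmatically; the values here are made up):

use clap::Parser;

// Trimmed copy of the Args struct above, kept to two fields for brevity.
#[derive(Parser, Debug)]
struct Args {
    #[clap(default_value = "32", long, short, env)]
    max_batch_size: usize,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
}

fn main() {
    // Equivalent to: text-generation-router --max-batch-size 64 -p 8080
    let args = Args::parse_from(["text-generation-router", "--max-batch-size", "64", "-p", "8080"]);
    assert_eq!(args.max_batch_size, 64);
    assert_eq!(args.port, 8080);

    // Omitted flags fall back to env vars (MAX_BATCH_SIZE, PORT), then to
    // the defaults (assuming those variables are unset here).
    let defaults = Args::parse_from(["text-generation-router"]);
    assert_eq!(defaults.port, 3000);
}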

router/src/server.rs

@@ -64,7 +64,7 @@ pub(crate) struct GenerateRequest {
 #[instrument(skip(state), fields(time, time_per_token))]
 async fn liveness(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
     state
-        .infer
+        .batcher
         .infer(
             1,
             GenerateRequest {
@@ -97,7 +97,7 @@ async fn generate(
         })
         .await?;
 
-    let generated_text = state.infer.infer(input_length, validated_request).await?;
+    let generated_text = state.batcher.infer(input_length, validated_request).await?;
 
     tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
     tracing::Span::current().record(
@@ -114,18 +114,14 @@ async fn generate(
 #[derive(Clone)]
 struct ServerState {
     validation: Validation,
-    infer: Batcher,
+    batcher: Batcher,
 }
 
-pub async fn run(client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) {
-    client.clear_cache().await.expect("Unable to clear cache");
-    tracing::info!("Connected");
-
-    let infer = Batcher::new(client);
+pub async fn run(max_batch_size: usize, client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr) {
+    let batcher = Batcher::new(client, max_batch_size);
     let validation = Validation::new(tokenizer);
-    let shared_state = ServerState { validation, infer };
+    let shared_state = ServerState { validation, batcher };
 
     let app = Router::new()
         .route("/generate", post(generate))

router/src/validation.rs

@@ -14,7 +14,7 @@ pub enum ValidationError {
     TopK,
     #[error("Max New Tokens must be < 512")]
     MaxNewTokens,
-    #[error("Inputs must have less than 512 tokens. Given: {0}")]
+    #[error("Inputs must have less than 1000 tokens. Given: {0}")]
     InputLength(usize),
 }
@@ -30,7 +30,7 @@ type ValidationRequest = (
 );
 
 #[derive(Debug, Clone)]
-pub(crate) struct Validation {
+pub struct Validation {
     sender: mpsc::Sender<ValidationRequest>,
 }
@@ -81,7 +81,7 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<Vali
         let inputs = tokenizer.encode(request.inputs.clone(), false).unwrap();
         let input_length = inputs.len();
 
-        if input_length > 512 {
+        if input_length > 1000 {
             response_tx
                 .send(Err(ValidationError::InputLength(input_length)))
                 .unwrap_or(());