From 5e5d8766a2c2a28b166fbecf7543f19a4bf2c9d8 Mon Sep 17 00:00:00 2001
From: Olivier Dehaene
Date: Mon, 17 Oct 2022 14:59:00 +0200
Subject: [PATCH] feat: Improve error handling

---
 Dockerfile                          |  3 +-
 README.md                           |  1 -
 aml/deployment.yaml                 | 14 +++----
 router/Cargo.lock                   | 33 ++++++++--------
 router/Cargo.toml                   | 10 ++++-
 router/client/src/client.rs         | 34 +++++++----------
 router/client/src/lib.rs            | 22 ++++++-----
 router/client/src/sharded_client.rs | 23 +++++------
 router/src/batcher.rs               | 44 ++++++++++++++++-----
 router/src/db.rs                    | 59 ++++++++++++++++-------------
 router/src/lib.rs                   |  8 ++++
 router/src/main.rs                  | 23 ++---------
 router/src/server.rs                | 53 ++++++++++----------------
 router/src/validation.rs            | 42 ++++++++++++++++----
 run.sh                              | 17 +++++++--
 server/bloom_inference/cli.py       | 42 ++++++++++++++++++++
 server/bloom_inference/main.py      | 30 ---------------
 server/bloom_inference/server.py    |  6 +--
 server/bloom_inference/utils.py     |  3 ++
 server/pyproject.toml               |  3 ++
 20 files changed, 267 insertions(+), 203 deletions(-)
 create mode 100644 router/src/lib.rs
 mode change 100755 => 100644 run.sh
 create mode 100644 server/bloom_inference/cli.py
 delete mode 100644 server/bloom_inference/main.py

diff --git a/Dockerfile b/Dockerfile
index a5161020..57c19ee2 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -18,6 +18,7 @@ ENV LANG=C.UTF-8 \
     MODEL_NAME=bigscience/bloom \
     NUM_GPUS=8 \
     CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+    NCCL_ASYNC_ERROR_HANDLING=1 \
    CUDA_HOME=/usr/local/cuda \
    LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
    CONDA_DEFAULT_ENV=text-generation \
@@ -51,7 +52,7 @@ RUN cd server && \
     /opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir

 # Install router
-COPY --from=builder /usr/local/cargo/bin/bloom-inference /usr/local/bin/bloom-inference
+COPY --from=builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router

 COPY run.sh .

 RUN chmod +x run.sh
diff --git a/README.md b/README.md
index a62daca1..78da4db2 100644
--- a/README.md
+++ b/README.md
@@ -48,5 +48,4 @@ python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-dire
 - [ ] Add tests
 - [ ] Add shutdown logic in router and server
 - [ ] Improve multi-processing logic in server
-- [ ] Improve error handling everywhere
 - [ ] Improve past key layer indexing?
\ No newline at end of file
diff --git a/aml/deployment.yaml b/aml/deployment.yaml
index 31cb09c5..be28ceef 100644
--- a/aml/deployment.yaml
+++ b/aml/deployment.yaml
@@ -8,7 +8,7 @@ environment_variables:
   MODEL_NAME: bigscience/bloom
   NUM_GPUS: 8
 environment:
-  image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
+  image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.3
 inference_config:
   liveness_route:
     port: 3000
@@ -24,15 +24,15 @@ request_settings:
   request_timeout_ms: 90000
   max_concurrent_requests_per_instance: 256
 liveness_probe:
-  initial_delay: 300
+  initial_delay: 600
   timeout: 20
-  period: 60
+  period: 120
   success_threshold: 1
-  failure_threshold: 60
+  failure_threshold: 3
 readiness_probe:
-  initial_delay: 300
+  initial_delay: 600
   timeout: 20
-  period: 60
+  period: 120
   success_threshold: 1
-  failure_threshold: 60
+  failure_threshold: 3
 instance_count: 1
diff --git a/router/Cargo.lock b/router/Cargo.lock
index 1f00df14..c912be89 100644
--- a/router/Cargo.lock
+++ b/router/Cargo.lock
@@ -149,22 +149,6 @@ dependencies = [
  "generic-array",
 ]

-[[package]]
-name = "bloom-inference"
-version = "0.1.0"
-dependencies = [
- "axum",
- "bloom-inference-client",
- "futures",
- "parking_lot",
- "serde",
- "serde_json",
- "tokenizers",
- "tokio",
- "tracing",
- "tracing-subscriber",
-]
-
 [[package]]
 name = "bloom-inference-client"
 version = "0.1.0"
@@ -1669,6 +1653,23 @@ dependencies = [
  "winapi",
 ]

+[[package]]
+name = "text-generation-router"
+version = "0.1.0"
+dependencies = [
+ "axum",
+ "bloom-inference-client",
+ "futures",
+ "parking_lot",
+ "serde",
+ "serde_json",
+ "thiserror",
+ "tokenizers",
+ "tokio",
+ "tracing",
+ "tracing-subscriber",
+]
+
 [[package]]
 name = "textwrap"
 version = "0.11.0"
diff --git a/router/Cargo.toml b/router/Cargo.toml
index c5e5bb89..95666dcf 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -1,8 +1,15 @@
 [package]
-name = "bloom-inference"
+name = "text-generation-router"
 version = "0.1.0"
 edition = "2021"

+[lib]
+path = "src/lib.rs"
+
+[[bin]]
+name = "text-generation-router"
+path = "src/main.rs"
+
 [dependencies]
 axum = { version = "0.5.16", features = ["json", "serde_json"] }
 bloom-inference-client = { path = "client" }
@@ -10,6 +17,7 @@ futures = "0.3.24"
 parking_lot = "0.12.1"
 serde = "1.0.145"
 serde_json = "1.0.85"
+thiserror = "1.0.37"
 tokenizers = "0.13.0"
 tokio = { version = "1.21.1", features = ["rt-multi-thread", "parking_lot", "sync"] }
 tracing = "0.1.36"
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 712c13da..e7189b89 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -1,45 +1,37 @@
 use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
 use crate::pb::generate::v1::*;
 use crate::Result;
-use std::time::Duration;
 use tonic::transport::{Channel, Uri};
-use tower::timeout::Timeout;
 use tracing::*;

 /// BLOOM Inference gRPC client
 #[derive(Clone)]
 pub struct Client {
-    stub: TextGenerationServiceClient<Timeout<Channel>>,
+    stub: TextGenerationServiceClient<Channel>,
 }

 impl Client {
-    /// Returns a client connected to the given url. Requests exceeding timeout will fail.
-    pub async fn connect(uri: Uri, timeout: Duration) -> Self {
-        let channel = Channel::builder(uri)
-            .connect()
-            .await
-            .expect("Transport error");
-        let timeout_channel = Timeout::new(channel, timeout);
+    /// Returns a client connected to the given url
+    pub async fn connect(uri: Uri) -> Result<Self> {
+        let channel = Channel::builder(uri).connect().await?;

-        Self {
-            stub: TextGenerationServiceClient::new(timeout_channel),
-        }
+        Ok(Self {
+            stub: TextGenerationServiceClient::new(channel),
+        })
     }

-    /// Returns a client connected to the given unix socket. Requests exceeding timeout will fail.
-    pub async fn connect_uds(path: String, timeout: Duration) -> Self {
+    /// Returns a client connected to the given unix socket
+    pub async fn connect_uds(path: String) -> Result<Self> {
         let channel = Channel::from_shared("http://[::]:50051".to_string())
             .unwrap()
             .connect_with_connector(tower::service_fn(move |_: Uri| {
                 tokio::net::UnixStream::connect(path.clone())
             }))
-            .await
-            .expect("Transport error");
-        let timeout_channel = Timeout::new(channel, timeout);
+            .await?;

-        Self {
-            stub: TextGenerationServiceClient::new(timeout_channel),
-        }
+        Ok(Self {
+            stub: TextGenerationServiceClient::new(channel),
+        })
     }

     #[instrument(skip(self))]
diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index bb9919e2..48b2650d 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -8,22 +8,26 @@ pub use client::Client;
 pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request};
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
-pub use tonic::transport::Uri;
+pub use tonic::transport;
 use tonic::Status;

 #[derive(Error, Debug, Clone)]
-#[error("Text generation client error: {msg:?}")]
-pub struct ClientError {
-    msg: String,
-    // source: Status,
+pub enum ClientError {
+    #[error("Could not connect to Text Generation server: {0:?}")]
+    Connection(String),
+    #[error("Server error: {0:?}")]
+    Generation(String),
 }

 impl From<Status> for ClientError {
     fn from(err: Status) -> Self {
-        Self {
-            msg: err.to_string(),
-            // source: err,
-        }
+        Self::Generation(err.to_string())
+    }
+}
+
+impl From<transport::Error> for ClientError {
+    fn from(err: transport::Error) -> Self {
+        Self::Connection(err.to_string())
     }
 }
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index 7af741f0..7134551e 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -1,7 +1,6 @@
 use crate::Result;
 use crate::{Batch, Client, GeneratedText};
 use futures::future::join_all;
-use std::time::Duration;
 use tokio::sync::{broadcast, mpsc};
 use tonic::transport::Uri;
@@ -69,24 +68,22 @@ impl ShardedClient {
         Self { request_tx }
     }

-    async fn from_master_client(mut master_client: Client) -> Self {
+    async fn from_master_client(mut master_client: Client) -> Result<Self> {
         let uris = master_client.service_discovery().await.unwrap();
-        let futures = uris
-            .into_iter()
-            .map(|path| Client::connect_uds(path, Duration::from_secs(5)));
-        let clients = join_all(futures).await;
-        Self::new(clients)
+        let futures = uris.into_iter().map(|path| Client::connect_uds(path));
+        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
+        Ok(Self::new(clients?))
     }

-    /// Returns a client connected to the given url. Requests exceeding timeout will fail.
-    pub async fn connect(uri: Uri, timeout: Duration) -> Self {
-        let master_client = Client::connect(uri, timeout).await;
+    /// Returns a client connected to the given url
+    pub async fn connect(uri: Uri) -> Result<Self> {
+        let master_client = Client::connect(uri).await?;
         Self::from_master_client(master_client).await
     }

-    /// Returns a client connected to the given unix socket. Requests exceeding timeout will fail.
-    pub async fn connect_uds(path: String, timeout: Duration) -> Self {
-        let master_client = Client::connect_uds(path, timeout).await;
+    /// Returns a client connected to the given unix socket
+    pub async fn connect_uds(path: String) -> Result<Self> {
+        let master_client = Client::connect_uds(path).await?;
         Self::from_master_client(master_client).await
     }
diff --git a/router/src/batcher.rs b/router/src/batcher.rs
index 2025cf62..94c72cc5 100644
--- a/router/src/batcher.rs
+++ b/router/src/batcher.rs
@@ -1,13 +1,30 @@
 use crate::server::GenerateRequest;
-use crate::Db;
+use crate::{Db, Entry};
+use axum::http::StatusCode;
 use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use std::future::Future;
 use std::sync::Arc;
+use thiserror::Error;
 use tokio::sync::{oneshot, Notify};

 const MAX_LENGTH: usize = 128;

-pub struct InferError {}
+#[derive(Debug, Error)]
+pub enum InferError {
+    #[error("Request failed during generation: {0}")]
+    GenerationError(String),
+    #[error("Model is overloaded")]
+    Overloaded,
+}
+
+impl From<InferError> for (StatusCode, String) {
+    fn from(err: InferError) -> Self {
+        match err {
+            InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
+            InferError::Overloaded => (StatusCode::TOO_MANY_REQUESTS, err.to_string()),
+        }
+    }
+}

 #[derive(Clone)]
 pub(crate) struct Batcher {
@@ -37,14 +54,18 @@ impl Batcher {
         request: GenerateRequest,
     ) -> Result<String, InferError> {
         if self.db.len() > MAX_LENGTH {
-            return Err(InferError {});
+            return Err(InferError::Overloaded);
         }
         let (request_tx, request_rx) = oneshot::channel();
-        self.db.append(input_length, request, request_tx);
+        self.db.append(Entry {
+            request,
+            response_tx: request_tx,
+            input_length,
+        });
         self.shared.batching_task.notify_waiters();
         match request_rx.await.unwrap() {
             Ok(output) => Ok(output),
-            Err(_) => Err(InferError {}),
+            Err(err) => Err(InferError::GenerationError(err.to_string())),
         }
     }
 }
@@ -108,7 +129,6 @@ async fn wrap_future(
             next_batch
         }
         Err(err) => {
-            println!("{:?}", err);
             send_error(err, request_ids, db);
             None
         }
@@ -117,14 +137,18 @@ async fn wrap_future(
 fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
     request_ids.into_iter().for_each(|id| {
-        let (_, response_tx) = db.remove(&id).unwrap();
-        response_tx.send(Err(error.clone())).unwrap_or(());
+        let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
+        // unwrap_or is valid here as we don't care if the receiver is gone.
+        entry.response_tx.send(Err(error.clone())).unwrap_or(());
     });
 }

 fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
     finished.into_iter().for_each(|output| {
-        let (_, response_tx) = db.remove(&output.request.unwrap().id).unwrap();
-        response_tx.send(Ok(output.output)).unwrap_or(());
+        let entry = db
+            .remove(&output.request.unwrap().id)
+            .expect("ID not found in db. This is a bug.");
+        // unwrap_or is valid here as we don't care if the receiver is gone.
+        entry.response_tx.send(Ok(output.output)).unwrap_or(());
     });
 }
diff --git a/router/src/db.rs b/router/src/db.rs
index 5118b2fc..03593fc0 100644
--- a/router/src/db.rs
+++ b/router/src/db.rs
@@ -1,11 +1,29 @@
 /// This code is massively inspired by Tokio mini-redis
-use crate::server::GenerateRequest;
+use crate::server::{GenerateParameters, GenerateRequest};
 use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
 use parking_lot::RwLock;
 use std::collections::BTreeMap;
 use std::sync::Arc;
 use tokio::sync::oneshot::Sender;

+#[derive(Debug)]
+pub(crate) struct Entry {
+    pub request: GenerateRequest,
+    pub response_tx: Sender<Result<String, ClientError>>,
+    pub input_length: usize,
+}
+
+impl From<GenerateParameters> for LogitsWarperParameters {
+    fn from(parameters: GenerateParameters) -> Self {
+        Self {
+            temperature: parameters.temperature,
+            top_k: parameters.top_k as u32,
+            top_p: parameters.top_p,
+            do_sample: parameters.do_sample,
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
 pub(crate) struct Db {
     pub shared: Arc<Shared>,
 }
@@ -18,7 +36,7 @@ pub struct Shared {
 #[derive(Debug)]
 struct State {
-    entries: BTreeMap<u64, (Request, Sender<Result<String, ClientError>>)>,
+    entries: BTreeMap<u64, Entry>,

     /// Identifier to use for the next expiration. Each expiration is associated
     /// with a unique identifier. See above for why.
@@ -44,37 +62,16 @@ impl Db {
         Self { shared }
     }

-    pub(crate) fn append(
-        &self,
-        input_length: usize,
-        request: GenerateRequest,
-        sender: Sender<Result<String, ClientError>>,
-    ) {
+    pub(crate) fn append(&self, entry: Entry) {
         let mut state = self.shared.state.write();
         let id = state.next_id;
         state.next_id += 1;
-        let parameters = Some(LogitsWarperParameters {
-            temperature: request.parameters.temperature,
-            top_k: request.parameters.top_k,
-            top_p: request.parameters.top_p,
-            do_sample: request.parameters.do_sample,
-        });
-        let request = Request {
-            id,
-            inputs: request.inputs,
-            input_length: input_length as u32,
-            parameters,
-            max_new_tokens: request.parameters.max_new_tokens,
-        };
-        state.entries.insert(id, (request, sender));
+        state.entries.insert(id, entry);
     }

-    pub(crate) fn remove(
-        &self,
-        id: &u64,
-    ) -> Option<(Request, Sender<Result<String, ClientError>>)> {
+    pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
         let mut state = self.shared.state.write();
         state.entries.remove(id)
     }
@@ -91,7 +88,15 @@ impl Db {
             .entries
             .range(state.next_batch_start_id..)
             .take(max_size)
-            .map(|(_, (request, _))| request.clone())
+            .map(|(id, entry)| Request {
+                id: *id,
+                inputs: entry.request.inputs.clone(),
+                input_length: entry.input_length as u32,
+                parameters: Some(LogitsWarperParameters::from(
+                    entry.request.parameters.clone(),
+                )),
+                max_new_tokens: entry.request.parameters.max_new_tokens,
+            })
             .collect();

         if requests.is_empty() {
diff --git a/router/src/lib.rs b/router/src/lib.rs
new file mode 100644
index 00000000..09dbdd12
--- /dev/null
+++ b/router/src/lib.rs
@@ -0,0 +1,8 @@
+mod batcher;
+mod db;
+pub mod server;
+mod validation;
+
+use batcher::Batcher;
+use db::{Db, Entry};
+use validation::Validation;
diff --git a/router/src/main.rs b/router/src/main.rs
index 2fe02944..5169a071 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -1,21 +1,8 @@
 use bloom_inference_client::ShardedClient;
 use std::net::SocketAddr;
-use std::time::Duration;
+use text_generation_router::server;
 use tokenizers::Tokenizer;

-mod server;
-mod validation;
-
-use validation::Validation;
-
-mod db;
-
-use db::Db;
-
-mod batcher;
-
-use batcher::Batcher;
-
 fn main() -> Result<(), std::io::Error> {
     let tokenizer = Tokenizer::from_pretrained("bigscience/bloom", None).unwrap();
@@ -26,11 +13,9 @@ fn main() -> Result<(), std::io::Error> {
         .block_on(async {
             tracing_subscriber::fmt::init();

-            let sharded_client = ShardedClient::connect_uds(
-                "/tmp/bloom-inference-0".to_string(),
-                Duration::from_secs(5),
-            )
-            .await;
+            let sharded_client = ShardedClient::connect_uds("/tmp/bloom-inference-0".to_string())
+                .await
+                .expect("Could not connect to server");
             sharded_client
                 .clear_cache()
                 .await
diff --git a/router/src/server.rs b/router/src/server.rs
index f113ce44..b8331da1 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1,9 +1,9 @@
-use bloom_inference_client::ShardedClient;
 use crate::{Batcher, Validation};
 use axum::extract::Extension;
 use axum::http::StatusCode;
 use axum::routing::{get, post};
 use axum::{Json, Router};
+use bloom_inference_client::ShardedClient;
 use serde::Deserialize;
 use std::net::SocketAddr;
 use tokenizers::Tokenizer;
@@ -15,7 +15,7 @@ pub(crate) struct GenerateParameters {
     #[serde(default = "default_temperature")]
     pub temperature: f32,
     #[serde(default = "default_top_k")]
-    pub top_k: u32,
+    pub top_k: i32,
     #[serde(default = "default_top_p")]
     pub top_p: f32,
     #[serde(default = "default_do_sample")]
     pub do_sample: bool,
@@ -28,7 +28,7 @@ fn default_temperature() -> f32 {
     1.0
 }

-fn default_top_k() -> u32 {
+fn default_top_k() -> i32 {
     0
 }

@@ -62,8 +62,8 @@ pub(crate) struct GenerateRequest {
 }

 #[instrument(skip(state), fields(time, time_per_token))]
-async fn liveness(state: Extension<ServerState>) -> Result<(), StatusCode> {
-    let output = state
+async fn liveness(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
+    state
         .infer
         .infer(
             1,
             GenerateRequest {
                 inputs: "liveness".to_string(),
                 parameters: GenerateParameters {
                     temperature: 1.0,
                     top_k: 0,
                     top_p: 1.0,
                     do_sample: false,
                     max_new_tokens: 1,
                 },
             },
         )
-        .await;
-
-    match output {
-        Ok(_) => Ok(()),
-        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
-    }
+        .await?;
+    Ok(())
 }

 #[instrument(skip(state), fields(time, time_per_token))]
 async fn generate(
     state: Extension<ServerState>,
     req: Json<GenerateRequest>,
-) -> Result<Json<serde_json::Value>, StatusCode> {
+) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
     let start = Instant::now();
-    let (input_length, validated_request) = match state
+    let (input_length, validated_request) = state
         .validation
         .validate(GenerateRequest {
             inputs: req.inputs.clone(),
             parameters: req.parameters.clone(),
         })
-        .await
-    {
-        Ok(result) => result,
-        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
-    };
+        .await?;

-    let output = state.infer.infer(input_length, validated_request).await;
+    let generated_text = state.infer.infer(input_length, validated_request).await?;

-    match output {
-        Ok(generated_text) => {
-            tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
-            tracing::Span::current().record(
-                "time_per_token",
-                format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
-            );
-            tracing::info!("response: {}", generated_text);
+    tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
+    tracing::Span::current().record(
+        "time_per_token",
+        format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
+    );
+    tracing::info!("response: {}", generated_text);

-            Ok(Json(serde_json::json!({
-                "generated_text": generated_text,
-            })))
-        }
-        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
-    }
+    Ok(Json(serde_json::json!({
+        "generated_text": generated_text,
+    })))
 }

 #[derive(Clone)]
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 6987894d..c9d391aa 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -1,9 +1,28 @@
 use crate::server::GenerateRequest;
+use axum::http::StatusCode;
+use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokio::sync::{mpsc, oneshot};

-#[derive(Debug)]
-pub struct ValidationError {}
+#[derive(Error, Debug)]
+pub enum ValidationError {
+    #[error("Temperature must be strictly positive")]
+    Temperature,
+    #[error("Top p must be <= 0.0 or > 1.0")]
+    TopP,
+    #[error("Top k must be strictly positive")]
+    TopK,
+    #[error("Max New Tokens must be < 512")]
+    MaxNewTokens,
+    #[error("Inputs must have less than 512 tokens. Given: {0}")]
+    InputLength(usize),
+}
+
+impl From<ValidationError> for (StatusCode, String) {
+    fn from(err: ValidationError) -> Self {
+        (StatusCode::BAD_REQUEST, err.to_string())
+    }
+}

 type ValidationRequest = (
     GenerateRequest,
@@ -39,15 +58,23 @@ impl Validation {
 async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<ValidationRequest>) {
     while let Some((request, response_tx)) = receiver.recv().await {
         if request.parameters.temperature < 0.0 {
-            response_tx.send(Err(ValidationError {})).unwrap_or(());
+            response_tx
+                .send(Err(ValidationError::Temperature))
+                .unwrap_or(());
             continue;
         }
         if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
-            response_tx.send(Err(ValidationError {})).unwrap_or(());
+            response_tx.send(Err(ValidationError::TopP)).unwrap_or(());
+            continue;
+        }
+        if request.parameters.top_k < 0 {
+            response_tx.send(Err(ValidationError::TopK)).unwrap_or(());
             continue;
         }
         if request.parameters.max_new_tokens > 512 {
-            response_tx.send(Err(ValidationError {})).unwrap_or(());
+            response_tx
+                .send(Err(ValidationError::MaxNewTokens))
+                .unwrap_or(());
             continue;
         }
@@ -55,11 +82,12 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver
         if input_length > 512 {
-            response_tx.send(Err(ValidationError {})).unwrap_or(());
+            response_tx
+                .send(Err(ValidationError::InputLength(input_length)))
+                .unwrap_or(());
             continue;
         }

         response_tx.send(Ok((input_length, request))).unwrap_or(());
     }
-    println!("drop here");
 }
diff --git a/run.sh b/run.sh
old mode 100755
new mode 100644
index 3b095541..30303501
--- a/run.sh
+++ b/run.sh
@@ -1,10 +1,12 @@
 #!/usr/bin/env bash

-server_cmd="python server/bloom_inference/main.py $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"
-$server_cmd &
+server_cmd="bloom-inference-server launcher $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"
+# Run in background
+$server_cmd 2>&1 > /dev/null &
+
+# Check if server is running by checking if the unix socket is created
 FILE=/tmp/bloom-inference-0
-
 while :
     do
         if test -S "$FILE"; then
@@ -18,4 +20,11 @@ while :
         sleep 1

-exec "bloom-inference"
+# Run in background
+text-generation-router &
+
+# Wait for any process to exit
+wait -n
+
+# Exit with status of process that exited first
+exit $?
\ No newline at end of file
diff --git a/server/bloom_inference/cli.py b/server/bloom_inference/cli.py
new file mode 100644
index 00000000..a5f84e77
--- /dev/null
+++ b/server/bloom_inference/cli.py
@@ -0,0 +1,42 @@
+import typer
+
+from pathlib import Path
+from torch.distributed.launcher import launch_agent, LaunchConfig
+from typing import Optional
+
+from bloom_inference import server
+
+app = typer.Typer()
+
+
+@app.command()
+def launcher(
+    model_name: str,
+    num_gpus: int = 1,
+    shard_directory: Optional[Path] = None,
+):
+    if num_gpus == 1:
+        serve(model_name, False, shard_directory)
+
+    else:
+        config = LaunchConfig(
+            min_nodes=1,
+            max_nodes=1,
+            nproc_per_node=num_gpus,
+            rdzv_backend="c10d",
+            max_restarts=0,
+        )
+        launch_agent(config, server.serve, [model_name, True, shard_directory])
+
+
+@app.command()
+def serve(
+    model_name: str,
+    sharded: bool = False,
+    shard_directory: Optional[Path] = None,
+):
+    server.serve(model_name, sharded, shard_directory)
+
+
+if __name__ == "__main__":
+    app()
diff --git a/server/bloom_inference/main.py b/server/bloom_inference/main.py
deleted file mode 100644
index b81ed040..00000000
--- a/server/bloom_inference/main.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import typer
-
-from pathlib import Path
-from torch.distributed.launcher import launch_agent, LaunchConfig
-from typing import Optional
-
-from bloom_inference.server import serve
-
-
-def main(
-    model_name: str,
-    num_gpus: int = 1,
-    shard_directory: Optional[Path] = None,
-):
-    if num_gpus == 1:
-        serve(model_name, False, shard_directory)
-
-    else:
-        config = LaunchConfig(
-            min_nodes=1,
-            max_nodes=1,
-            nproc_per_node=num_gpus,
-            rdzv_backend="c10d",
-            max_restarts=0,
-        )
-        launch_agent(config, serve, [model_name, True, shard_directory])
-
-
-if __name__ == "__main__":
-    typer.run(main)
diff --git a/server/bloom_inference/server.py b/server/bloom_inference/server.py
index 3a509169..e89706e0 100644
--- a/server/bloom_inference/server.py
+++ b/server/bloom_inference/server.py
@@ -1,4 +1,6 @@
 import asyncio
+import os
+
 from grpc import aio
 from grpc_reflection.v1alpha import reflection
@@ -143,7 +145,3 @@ def serve(model_name, sharded, shard_directory):
         await server.wait_for_termination()

     asyncio.run(serve_inner(model_name, sharded, shard_directory))
-
-
-if __name__ == "__main__":
-    serve("bigscience/bloom-560m", True, Path("/tmp/models"))
diff --git a/server/bloom_inference/utils.py b/server/bloom_inference/utils.py
index db02dadb..fe2c913e 100644
--- a/server/bloom_inference/utils.py
+++ b/server/bloom_inference/utils.py
@@ -2,6 +2,8 @@ import os
 import contextlib
 import torch
 import torch.distributed
+
+from datetime import timedelta
 from transformers.generation_logits_process import (
     LogitsProcessorList,
     TemperatureLogitsWarper,
@@ -79,6 +81,7 @@ def initialize_torch_distributed():
         backend=backend,
         world_size=world_size,
         rank=rank,
+        timeout=timedelta(seconds=60),
         init_method="tcp://localhost:6000",
     )
diff --git a/server/pyproject.toml b/server/pyproject.toml
index 80d95426..1dd8ae27 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -4,6 +4,9 @@ version = "0.1.0"
 description = "BLOOM Inference Python gRPC Server"
 authors = ["Olivier Dehaene "]

+[tool.poetry.scripts]
+bloom-inference-server = 'bloom_inference.cli:app'
+
 [tool.poetry.dependencies]
 python = "^3.9"
 protobuf = "^4.21.7"