feat: Improve error handling

This commit is contained in:
Olivier Dehaene 2022-10-17 14:59:00 +02:00
parent 00e6ce44b1
commit 5e5d8766a2
20 changed files with 267 additions and 203 deletions

View File

@ -18,6 +18,7 @@ ENV LANG=C.UTF-8 \
MODEL_NAME=bigscience/bloom \ MODEL_NAME=bigscience/bloom \
NUM_GPUS=8 \ NUM_GPUS=8 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NCCL_ASYNC_ERROR_HANDLING=1 \
CUDA_HOME=/usr/local/cuda \ CUDA_HOME=/usr/local/cuda \
LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \ LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
CONDA_DEFAULT_ENV=text-generation \ CONDA_DEFAULT_ENV=text-generation \
@ -51,7 +52,7 @@ RUN cd server && \
/opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir /opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir
# Install router # Install router
COPY --from=builder /usr/local/cargo/bin/bloom-inference /usr/local/bin/bloom-inference COPY --from=builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
COPY run.sh . COPY run.sh .
RUN chmod +x run.sh RUN chmod +x run.sh

View File

@ -48,5 +48,4 @@ python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-dire
- [ ] Add tests - [ ] Add tests
- [ ] Add shutdown logic in router and server - [ ] Add shutdown logic in router and server
- [ ] Improve multi-processing logic in server - [ ] Improve multi-processing logic in server
- [ ] Improve error handling everywhere
- [ ] Improve past key layer indexing? - [ ] Improve past key layer indexing?

View File

@ -8,7 +8,7 @@ environment_variables:
MODEL_NAME: bigscience/bloom MODEL_NAME: bigscience/bloom
NUM_GPUS: 8 NUM_GPUS: 8
environment: environment:
image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1 image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.3
inference_config: inference_config:
liveness_route: liveness_route:
port: 3000 port: 3000
@ -24,15 +24,15 @@ request_settings:
request_timeout_ms: 90000 request_timeout_ms: 90000
max_concurrent_requests_per_instance: 256 max_concurrent_requests_per_instance: 256
liveness_probe: liveness_probe:
initial_delay: 300 initial_delay: 600
timeout: 20 timeout: 20
period: 60 period: 120
success_threshold: 1 success_threshold: 1
failure_threshold: 60 failure_threshold: 3
readiness_probe: readiness_probe:
initial_delay: 300 initial_delay: 600
timeout: 20 timeout: 20
period: 60 period: 120
success_threshold: 1 success_threshold: 1
failure_threshold: 60 failure_threshold: 3
instance_count: 1 instance_count: 1

33
router/Cargo.lock generated
View File

@ -149,22 +149,6 @@ dependencies = [
"generic-array", "generic-array",
] ]
[[package]]
name = "bloom-inference"
version = "0.1.0"
dependencies = [
"axum",
"bloom-inference-client",
"futures",
"parking_lot",
"serde",
"serde_json",
"tokenizers",
"tokio",
"tracing",
"tracing-subscriber",
]
[[package]] [[package]]
name = "bloom-inference-client" name = "bloom-inference-client"
version = "0.1.0" version = "0.1.0"
@ -1669,6 +1653,23 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "text-generation-router"
version = "0.1.0"
dependencies = [
"axum",
"bloom-inference-client",
"futures",
"parking_lot",
"serde",
"serde_json",
"thiserror",
"tokenizers",
"tokio",
"tracing",
"tracing-subscriber",
]
[[package]] [[package]]
name = "textwrap" name = "textwrap"
version = "0.11.0" version = "0.11.0"

View File

@ -1,8 +1,15 @@
[package] [package]
name = "bloom-inference" name = "text-generation-router"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
[lib]
path = "src/lib.rs"
[[bin]]
name = "text-generation-router"
path = "src/main.rs"
[dependencies] [dependencies]
axum = { version = "0.5.16", features = ["json", "serde_json"] } axum = { version = "0.5.16", features = ["json", "serde_json"] }
bloom-inference-client = { path = "client" } bloom-inference-client = { path = "client" }
@ -10,6 +17,7 @@ futures = "0.3.24"
parking_lot = "0.12.1" parking_lot = "0.12.1"
serde = "1.0.145" serde = "1.0.145"
serde_json = "1.0.85" serde_json = "1.0.85"
thiserror = "1.0.37"
tokenizers = "0.13.0" tokenizers = "0.13.0"
tokio = { version = "1.21.1", features = ["rt-multi-thread", "parking_lot", "sync"] } tokio = { version = "1.21.1", features = ["rt-multi-thread", "parking_lot", "sync"] }
tracing = "0.1.36" tracing = "0.1.36"

View File

@ -1,45 +1,37 @@
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient; use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*; use crate::pb::generate::v1::*;
use crate::Result; use crate::Result;
use std::time::Duration;
use tonic::transport::{Channel, Uri}; use tonic::transport::{Channel, Uri};
use tower::timeout::Timeout;
use tracing::*; use tracing::*;
/// BLOOM Inference gRPC client /// BLOOM Inference gRPC client
#[derive(Clone)] #[derive(Clone)]
pub struct Client { pub struct Client {
stub: TextGenerationServiceClient<Timeout<Channel>>, stub: TextGenerationServiceClient<Channel>,
} }
impl Client { impl Client {
/// Returns a client connected to the given url. Requests exceeding timeout will fail. /// Returns a client connected to the given url
pub async fn connect(uri: Uri, timeout: Duration) -> Self { pub async fn connect(uri: Uri) -> Result<Self> {
let channel = Channel::builder(uri) let channel = Channel::builder(uri).connect().await?;
.connect()
.await
.expect("Transport error");
let timeout_channel = Timeout::new(channel, timeout);
Self { Ok(Self {
stub: TextGenerationServiceClient::new(timeout_channel), stub: TextGenerationServiceClient::new(channel),
} })
} }
/// Returns a client connected to the given unix socket. Requests exceeding timeout will fail. /// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String, timeout: Duration) -> Self { pub async fn connect_uds(path: String) -> Result<Self> {
let channel = Channel::from_shared("http://[::]:50051".to_string()) let channel = Channel::from_shared("http://[::]:50051".to_string())
.unwrap() .unwrap()
.connect_with_connector(tower::service_fn(move |_: Uri| { .connect_with_connector(tower::service_fn(move |_: Uri| {
tokio::net::UnixStream::connect(path.clone()) tokio::net::UnixStream::connect(path.clone())
})) }))
.await .await?;
.expect("Transport error");
let timeout_channel = Timeout::new(channel, timeout);
Self { Ok(Self {
stub: TextGenerationServiceClient::new(timeout_channel), stub: TextGenerationServiceClient::new(channel),
} })
} }
#[instrument(skip(self))] #[instrument(skip(self))]

View File

@ -8,22 +8,26 @@ pub use client::Client;
pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request}; pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request};
pub use sharded_client::ShardedClient; pub use sharded_client::ShardedClient;
use thiserror::Error; use thiserror::Error;
pub use tonic::transport::Uri; pub use tonic::transport;
use tonic::Status; use tonic::Status;
#[derive(Error, Debug, Clone)] #[derive(Error, Debug, Clone)]
#[error("Text generation client error: {msg:?}")] pub enum ClientError {
pub struct ClientError { #[error("Could not connect to Text Generation server: {0:?}")]
msg: String, Connection(String),
// source: Status, #[error("Server error: {0:?}")]
Generation(String),
} }
impl From<Status> for ClientError { impl From<Status> for ClientError {
fn from(err: Status) -> Self { fn from(err: Status) -> Self {
Self { Self::Generation(err.to_string())
msg: err.to_string(), }
// source: err, }
}
impl From<transport::Error> for ClientError {
fn from(err: transport::Error) -> Self {
Self::Connection(err.to_string())
} }
} }

View File

@ -1,7 +1,6 @@
use crate::Result; use crate::Result;
use crate::{Batch, Client, GeneratedText}; use crate::{Batch, Client, GeneratedText};
use futures::future::join_all; use futures::future::join_all;
use std::time::Duration;
use tokio::sync::{broadcast, mpsc}; use tokio::sync::{broadcast, mpsc};
use tonic::transport::Uri; use tonic::transport::Uri;
@ -69,24 +68,22 @@ impl ShardedClient {
Self { request_tx } Self { request_tx }
} }
async fn from_master_client(mut master_client: Client) -> Self { async fn from_master_client(mut master_client: Client) -> Result<Self> {
let uris = master_client.service_discovery().await.unwrap(); let uris = master_client.service_discovery().await.unwrap();
let futures = uris let futures = uris.into_iter().map(|path| Client::connect_uds(path));
.into_iter() let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
.map(|path| Client::connect_uds(path, Duration::from_secs(5))); Ok(Self::new(clients?))
let clients = join_all(futures).await;
Self::new(clients)
} }
/// Returns a client connected to the given url. Requests exceeding timeout will fail. /// Returns a client connected to the given url
pub async fn connect(uri: Uri, timeout: Duration) -> Self { pub async fn connect(uri: Uri) -> Result<Self> {
let master_client = Client::connect(uri, timeout).await; let master_client = Client::connect(uri).await?;
Self::from_master_client(master_client).await Self::from_master_client(master_client).await
} }
/// Returns a client connected to the given unix socket. Requests exceeding timeout will fail. /// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String, timeout: Duration) -> Self { pub async fn connect_uds(path: String) -> Result<Self> {
let master_client = Client::connect_uds(path, timeout).await; let master_client = Client::connect_uds(path).await?;
Self::from_master_client(master_client).await Self::from_master_client(master_client).await
} }

View File

@ -1,13 +1,30 @@
use crate::server::GenerateRequest; use crate::server::GenerateRequest;
use crate::Db; use crate::{Db, Entry};
use axum::http::StatusCode;
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient}; use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
use std::future::Future; use std::future::Future;
use std::sync::Arc; use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{oneshot, Notify}; use tokio::sync::{oneshot, Notify};
const MAX_LENGTH: usize = 128; const MAX_LENGTH: usize = 128;
pub struct InferError {} #[derive(Debug, Error)]
pub enum InferError {
#[error("Request failed during generation: {0}")]
GenerationError(String),
#[error("Model is overloaded")]
Overloaded,
}
impl From<InferError> for (StatusCode, String) {
fn from(err: InferError) -> Self {
match err {
InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
InferError::Overloaded => (StatusCode::TOO_MANY_REQUESTS, err.to_string()),
}
}
}
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct Batcher { pub(crate) struct Batcher {
@ -37,14 +54,18 @@ impl Batcher {
request: GenerateRequest, request: GenerateRequest,
) -> Result<String, InferError> { ) -> Result<String, InferError> {
if self.db.len() > MAX_LENGTH { if self.db.len() > MAX_LENGTH {
return Err(InferError {}); return Err(InferError::Overloaded);
} }
let (request_tx, request_rx) = oneshot::channel(); let (request_tx, request_rx) = oneshot::channel();
self.db.append(input_length, request, request_tx); self.db.append(Entry {
request,
response_tx: request_tx,
input_length,
});
self.shared.batching_task.notify_waiters(); self.shared.batching_task.notify_waiters();
match request_rx.await.unwrap() { match request_rx.await.unwrap() {
Ok(output) => Ok(output), Ok(output) => Ok(output),
Err(_) => Err(InferError {}), Err(err) => Err(InferError::GenerationError(err.to_string())),
} }
} }
} }
@ -108,7 +129,6 @@ async fn wrap_future(
next_batch next_batch
} }
Err(err) => { Err(err) => {
println!("{:?}", err);
send_error(err, request_ids, db); send_error(err, request_ids, db);
None None
} }
@ -117,14 +137,18 @@ async fn wrap_future(
fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) { fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
request_ids.into_iter().for_each(|id| { request_ids.into_iter().for_each(|id| {
let (_, response_tx) = db.remove(&id).unwrap(); let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
response_tx.send(Err(error.clone())).unwrap_or(()); // unwrap_or is valid here as we don't care if the receiver is gone.
entry.response_tx.send(Err(error.clone())).unwrap_or(());
}); });
} }
fn send_generated(finished: Vec<GeneratedText>, db: &Db) { fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
finished.into_iter().for_each(|output| { finished.into_iter().for_each(|output| {
let (_, response_tx) = db.remove(&output.request.unwrap().id).unwrap(); let entry = db
response_tx.send(Ok(output.output)).unwrap_or(()); .remove(&output.request.unwrap().id)
.expect("ID not found in db. This is a bug.");
// unwrap_or is valid here as we don't care if the receiver is gone.
entry.response_tx.send(Ok(output.output)).unwrap_or(());
}); });
} }

View File

@ -1,11 +1,29 @@
/// This code is massively inspired by Tokio mini-redis /// This code is massively inspired by Tokio mini-redis
use crate::server::GenerateRequest; use crate::server::{GenerateParameters, GenerateRequest};
use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request}; use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
use parking_lot::RwLock; use parking_lot::RwLock;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::oneshot::Sender; use tokio::sync::oneshot::Sender;
#[derive(Debug)]
pub(crate) struct Entry {
pub request: GenerateRequest,
pub response_tx: Sender<Result<String, ClientError>>,
pub input_length: usize,
}
impl From<GenerateParameters> for LogitsWarperParameters {
fn from(parameters: GenerateParameters) -> Self {
Self {
temperature: parameters.temperature,
top_k: parameters.top_k as u32,
top_p: parameters.top_p,
do_sample: parameters.do_sample,
}
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct Db { pub(crate) struct Db {
pub shared: Arc<Shared>, pub shared: Arc<Shared>,
@ -18,7 +36,7 @@ pub struct Shared {
#[derive(Debug)] #[derive(Debug)]
struct State { struct State {
entries: BTreeMap<u64, (Request, Sender<Result<String, ClientError>>)>, entries: BTreeMap<u64, Entry>,
/// Identifier to use for the next expiration. Each expiration is associated /// Identifier to use for the next expiration. Each expiration is associated
/// with a unique identifier. See above for why. /// with a unique identifier. See above for why.
@ -44,37 +62,16 @@ impl Db {
Self { shared } Self { shared }
} }
pub(crate) fn append( pub(crate) fn append(&self, entry: Entry) {
&self,
input_length: usize,
request: GenerateRequest,
sender: Sender<Result<String, ClientError>>,
) {
let mut state = self.shared.state.write(); let mut state = self.shared.state.write();
let id = state.next_id; let id = state.next_id;
state.next_id += 1; state.next_id += 1;
let parameters = Some(LogitsWarperParameters { state.entries.insert(id, entry);
temperature: request.parameters.temperature,
top_k: request.parameters.top_k,
top_p: request.parameters.top_p,
do_sample: request.parameters.do_sample,
});
let request = Request {
id,
inputs: request.inputs,
input_length: input_length as u32,
parameters,
max_new_tokens: request.parameters.max_new_tokens,
};
state.entries.insert(id, (request, sender));
} }
pub(crate) fn remove( pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
&self,
id: &u64,
) -> Option<(Request, Sender<Result<String, ClientError>>)> {
let mut state = self.shared.state.write(); let mut state = self.shared.state.write();
state.entries.remove(id) state.entries.remove(id)
} }
@ -91,7 +88,15 @@ impl Db {
.entries .entries
.range(state.next_batch_start_id..) .range(state.next_batch_start_id..)
.take(max_size) .take(max_size)
.map(|(_, (request, _))| request.clone()) .map(|(id, entry)| Request {
id: *id,
inputs: entry.request.inputs.clone(),
input_length: entry.input_length as u32,
parameters: Some(LogitsWarperParameters::from(
entry.request.parameters.clone(),
)),
max_new_tokens: entry.request.parameters.max_new_tokens,
})
.collect(); .collect();
if requests.is_empty() { if requests.is_empty() {

8
router/src/lib.rs Normal file
View File

@ -0,0 +1,8 @@
mod batcher;
mod db;
pub mod server;
mod validation;
use batcher::Batcher;
use db::{Db, Entry};
use validation::Validation;

View File

@ -1,21 +1,8 @@
use bloom_inference_client::ShardedClient; use bloom_inference_client::ShardedClient;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::time::Duration; use text_generation_router::server;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
mod server;
mod validation;
use validation::Validation;
mod db;
use db::Db;
mod batcher;
use batcher::Batcher;
fn main() -> Result<(), std::io::Error> { fn main() -> Result<(), std::io::Error> {
let tokenizer = Tokenizer::from_pretrained("bigscience/bloom", None).unwrap(); let tokenizer = Tokenizer::from_pretrained("bigscience/bloom", None).unwrap();
@ -26,11 +13,9 @@ fn main() -> Result<(), std::io::Error> {
.block_on(async { .block_on(async {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
let sharded_client = ShardedClient::connect_uds( let sharded_client = ShardedClient::connect_uds("/tmp/bloom-inference-0".to_string())
"/tmp/bloom-inference-0".to_string(), .await
Duration::from_secs(5), .expect("Could not connect to server");
)
.await;
sharded_client sharded_client
.clear_cache() .clear_cache()
.await .await

View File

@ -1,9 +1,9 @@
use bloom_inference_client::ShardedClient;
use crate::{Batcher, Validation}; use crate::{Batcher, Validation};
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{Json, Router}; use axum::{Json, Router};
use bloom_inference_client::ShardedClient;
use serde::Deserialize; use serde::Deserialize;
use std::net::SocketAddr; use std::net::SocketAddr;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
@ -15,7 +15,7 @@ pub(crate) struct GenerateParameters {
#[serde(default = "default_temperature")] #[serde(default = "default_temperature")]
pub temperature: f32, pub temperature: f32,
#[serde(default = "default_top_k")] #[serde(default = "default_top_k")]
pub top_k: u32, pub top_k: i32,
#[serde(default = "default_top_p")] #[serde(default = "default_top_p")]
pub top_p: f32, pub top_p: f32,
#[serde(default = "default_do_sample")] #[serde(default = "default_do_sample")]
@ -28,7 +28,7 @@ fn default_temperature() -> f32 {
1.0 1.0
} }
fn default_top_k() -> u32 { fn default_top_k() -> i32 {
0 0
} }
@ -62,8 +62,8 @@ pub(crate) struct GenerateRequest {
} }
#[instrument(skip(state), fields(time, time_per_token))] #[instrument(skip(state), fields(time, time_per_token))]
async fn liveness(state: Extension<ServerState>) -> Result<(), StatusCode> { async fn liveness(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
let output = state state
.infer .infer
.infer( .infer(
1, 1,
@ -78,50 +78,37 @@ async fn liveness(state: Extension<ServerState>) -> Result<(), StatusCode> {
}, },
}, },
) )
.await; .await?;
Ok(())
match output {
Ok(_) => Ok(()),
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
}
} }
#[instrument(skip(state), fields(time, time_per_token))] #[instrument(skip(state), fields(time, time_per_token))]
async fn generate( async fn generate(
state: Extension<ServerState>, state: Extension<ServerState>,
req: Json<GenerateRequest>, req: Json<GenerateRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> { ) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
let start = Instant::now(); let start = Instant::now();
let (input_length, validated_request) = match state let (input_length, validated_request) = state
.validation .validation
.validate(GenerateRequest { .validate(GenerateRequest {
inputs: req.inputs.clone(), inputs: req.inputs.clone(),
parameters: req.parameters.clone(), parameters: req.parameters.clone(),
}) })
.await .await?;
{
Ok(result) => result,
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
};
let output = state.infer.infer(input_length, validated_request).await; let generated_text = state.infer.infer(input_length, validated_request).await?;
match output { tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
Ok(generated_text) => { tracing::Span::current().record(
tracing::Span::current().record("time", format!("{:?}", start.elapsed())); "time_per_token",
tracing::Span::current().record( format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
"time_per_token", );
format!("{:?}", start.elapsed() / req.parameters.max_new_tokens), tracing::info!("response: {}", generated_text);
);
tracing::info!("response: {}", generated_text);
Ok(Json(serde_json::json!({ Ok(Json(serde_json::json!({
"generated_text": generated_text, "generated_text": generated_text,
}))) })))
}
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
}
} }
#[derive(Clone)] #[derive(Clone)]

View File

@ -1,9 +1,28 @@
use crate::server::GenerateRequest; use crate::server::GenerateRequest;
use axum::http::StatusCode;
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
#[derive(Debug)] #[derive(Error, Debug)]
pub struct ValidationError {} pub enum ValidationError {
#[error("Temperature must be strictly positive")]
Temperature,
#[error("Top p must be <= 0.0 or > 1.0")]
TopP,
#[error("Top k must be strictly positive")]
TopK,
#[error("Max New Tokens must be < 512")]
MaxNewTokens,
#[error("Inputs must have less than 512 tokens. Given: {0}")]
InputLength(usize),
}
impl From<ValidationError> for (StatusCode, String) {
fn from(err: ValidationError) -> Self {
(StatusCode::BAD_REQUEST, err.to_string())
}
}
type ValidationRequest = ( type ValidationRequest = (
GenerateRequest, GenerateRequest,
@ -39,15 +58,23 @@ impl Validation {
async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<ValidationRequest>) { async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<ValidationRequest>) {
while let Some((request, response_tx)) = receiver.recv().await { while let Some((request, response_tx)) = receiver.recv().await {
if request.parameters.temperature < 0.0 { if request.parameters.temperature < 0.0 {
response_tx.send(Err(ValidationError {})).unwrap_or(()); response_tx
.send(Err(ValidationError::Temperature))
.unwrap_or(());
continue; continue;
} }
if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 { if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
response_tx.send(Err(ValidationError {})).unwrap_or(()); response_tx.send(Err(ValidationError::TopP)).unwrap_or(());
continue;
}
if request.parameters.top_k < 0 {
response_tx.send(Err(ValidationError::TopK)).unwrap_or(());
continue; continue;
} }
if request.parameters.max_new_tokens > 512 { if request.parameters.max_new_tokens > 512 {
response_tx.send(Err(ValidationError {})).unwrap_or(()); response_tx
.send(Err(ValidationError::MaxNewTokens))
.unwrap_or(());
continue; continue;
} }
@ -55,11 +82,12 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<Vali
let input_length = inputs.len(); let input_length = inputs.len();
if input_length > 512 { if input_length > 512 {
response_tx.send(Err(ValidationError {})).unwrap_or(()); response_tx
.send(Err(ValidationError::InputLength(input_length)))
.unwrap_or(());
continue; continue;
} }
response_tx.send(Ok((input_length, request))).unwrap_or(()); response_tx.send(Ok((input_length, request))).unwrap_or(());
} }
println!("drop here");
} }

17
run.sh Executable file → Normal file
View File

@ -1,10 +1,12 @@
#!/usr/bin/env bash #!/usr/bin/env bash
server_cmd="python server/bloom_inference/main.py $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH" server_cmd="bloom-inference-server launcher $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"
$server_cmd &
# Run in background
$server_cmd 2>&1 > /dev/null &
# Check if server is running by checking if the unix socket is created
FILE=/tmp/bloom-inference-0 FILE=/tmp/bloom-inference-0
while : while :
do do
if test -S "$FILE"; then if test -S "$FILE"; then
@ -18,4 +20,11 @@ while :
sleep 1 sleep 1
exec "bloom-inference" # Run in background
text-generation-router &
# Wait for any process to exit
wait -n
# Exit with status of process that exited first
exit $?

View File

@ -0,0 +1,42 @@
import typer
from pathlib import Path
from torch.distributed.launcher import launch_agent, LaunchConfig
from typing import Optional
from bloom_inference import server
app = typer.Typer()
@app.command()
def launcher(
model_name: str,
num_gpus: int = 1,
shard_directory: Optional[Path] = None,
):
if num_gpus == 1:
serve(model_name, False, shard_directory)
else:
config = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=num_gpus,
rdzv_backend="c10d",
max_restarts=0,
)
launch_agent(config, server.serve, [model_name, True, shard_directory])
@app.command()
def serve(
model_name: str,
sharded: bool = False,
shard_directory: Optional[Path] = None,
):
server.serve(model_name, sharded, shard_directory)
if __name__ == "__main__":
app()

View File

@ -1,30 +0,0 @@
import typer
from pathlib import Path
from torch.distributed.launcher import launch_agent, LaunchConfig
from typing import Optional
from bloom_inference.server import serve
def main(
model_name: str,
num_gpus: int = 1,
shard_directory: Optional[Path] = None,
):
if num_gpus == 1:
serve(model_name, False, shard_directory)
else:
config = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=num_gpus,
rdzv_backend="c10d",
max_restarts=0,
)
launch_agent(config, serve, [model_name, True, shard_directory])
if __name__ == "__main__":
typer.run(main)

View File

@ -1,4 +1,6 @@
import asyncio import asyncio
import os
from grpc import aio from grpc import aio
from grpc_reflection.v1alpha import reflection from grpc_reflection.v1alpha import reflection
@ -143,7 +145,3 @@ def serve(model_name, sharded, shard_directory):
await server.wait_for_termination() await server.wait_for_termination()
asyncio.run(serve_inner(model_name, sharded, shard_directory)) asyncio.run(serve_inner(model_name, sharded, shard_directory))
if __name__ == "__main__":
serve("bigscience/bloom-560m", True, Path("/tmp/models"))

View File

@ -2,6 +2,8 @@ import os
import contextlib import contextlib
import torch import torch
import torch.distributed import torch.distributed
from datetime import timedelta
from transformers.generation_logits_process import ( from transformers.generation_logits_process import (
LogitsProcessorList, LogitsProcessorList,
TemperatureLogitsWarper, TemperatureLogitsWarper,
@ -79,6 +81,7 @@ def initialize_torch_distributed():
backend=backend, backend=backend,
world_size=world_size, world_size=world_size,
rank=rank, rank=rank,
timeout=timedelta(seconds=60),
init_method="tcp://localhost:6000", init_method="tcp://localhost:6000",
) )

View File

@ -4,6 +4,9 @@ version = "0.1.0"
description = "BLOOM Inference Python gRPC Server" description = "BLOOM Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"] authors = ["Olivier Dehaene <olivier@huggingface.co>"]
[tool.poetry.scripts]
bloom-inference-server = 'bloom_inference.cli:app'
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.9" python = "^3.9"
protobuf = "^4.21.7" protobuf = "^4.21.7"