feat: Improve error handling

parent 00e6ce44b1
commit 5e5d8766a2
@@ -18,6 +18,7 @@ ENV LANG=C.UTF-8 \
     MODEL_NAME=bigscience/bloom \
     NUM_GPUS=8 \
     CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+    NCCL_ASYNC_ERROR_HANDLING=1 \
    CUDA_HOME=/usr/local/cuda \
    LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
    CONDA_DEFAULT_ENV=text-generation \
@@ -51,7 +52,7 @@ RUN cd server && \
     /opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir

 # Install router
-COPY --from=builder /usr/local/cargo/bin/bloom-inference /usr/local/bin/bloom-inference
+COPY --from=builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router

 COPY run.sh .
 RUN chmod +x run.sh
@@ -48,5 +48,4 @@ python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-dire
 - [ ] Add tests
 - [ ] Add shutdown logic in router and server
 - [ ] Improve multi-processing logic in server
-- [ ] Improve error handling everywhere
 - [ ] Improve past key layer indexing?
@@ -8,7 +8,7 @@ environment_variables:
   MODEL_NAME: bigscience/bloom
   NUM_GPUS: 8
 environment:
-  image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
+  image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.3
 inference_config:
   liveness_route:
     port: 3000
@@ -24,15 +24,15 @@ request_settings:
   request_timeout_ms: 90000
   max_concurrent_requests_per_instance: 256
 liveness_probe:
-  initial_delay: 300
+  initial_delay: 600
   timeout: 20
-  period: 60
+  period: 120
   success_threshold: 1
-  failure_threshold: 60
+  failure_threshold: 3
 readiness_probe:
-  initial_delay: 300
+  initial_delay: 600
   timeout: 20
-  period: 60
+  period: 120
   success_threshold: 1
-  failure_threshold: 60
+  failure_threshold: 3
 instance_count: 1
@@ -149,22 +149,6 @@ dependencies = [
  "generic-array",
 ]

-[[package]]
-name = "bloom-inference"
-version = "0.1.0"
-dependencies = [
- "axum",
- "bloom-inference-client",
- "futures",
- "parking_lot",
- "serde",
- "serde_json",
- "tokenizers",
- "tokio",
- "tracing",
- "tracing-subscriber",
-]
-
 [[package]]
 name = "bloom-inference-client"
 version = "0.1.0"
@@ -1669,6 +1653,23 @@ dependencies = [
  "winapi",
 ]

+[[package]]
+name = "text-generation-router"
+version = "0.1.0"
+dependencies = [
+ "axum",
+ "bloom-inference-client",
+ "futures",
+ "parking_lot",
+ "serde",
+ "serde_json",
+ "thiserror",
+ "tokenizers",
+ "tokio",
+ "tracing",
+ "tracing-subscriber",
+]
+
 [[package]]
 name = "textwrap"
 version = "0.11.0"
@@ -1,8 +1,15 @@
 [package]
-name = "bloom-inference"
+name = "text-generation-router"
 version = "0.1.0"
 edition = "2021"

+[lib]
+path = "src/lib.rs"
+
+[[bin]]
+name = "text-generation-router"
+path = "src/main.rs"
+
 [dependencies]
 axum = { version = "0.5.16", features = ["json", "serde_json"] }
 bloom-inference-client = { path = "client" }
@@ -10,6 +17,7 @@ futures = "0.3.24"
 parking_lot = "0.12.1"
 serde = "1.0.145"
 serde_json = "1.0.85"
+thiserror = "1.0.37"
 tokenizers = "0.13.0"
 tokio = { version = "1.21.1", features = ["rt-multi-thread", "parking_lot", "sync"] }
 tracing = "0.1.36"
@@ -1,45 +1,37 @@
 use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
 use crate::pb::generate::v1::*;
 use crate::Result;
-use std::time::Duration;
 use tonic::transport::{Channel, Uri};
-use tower::timeout::Timeout;
 use tracing::*;

 /// BLOOM Inference gRPC client
 #[derive(Clone)]
 pub struct Client {
-    stub: TextGenerationServiceClient<Timeout<Channel>>,
+    stub: TextGenerationServiceClient<Channel>,
 }

 impl Client {
-    /// Returns a client connected to the given url. Requests exceeding timeout will fail.
-    pub async fn connect(uri: Uri, timeout: Duration) -> Self {
-        let channel = Channel::builder(uri)
-            .connect()
-            .await
-            .expect("Transport error");
-        let timeout_channel = Timeout::new(channel, timeout);
+    /// Returns a client connected to the given url
+    pub async fn connect(uri: Uri) -> Result<Self> {
+        let channel = Channel::builder(uri).connect().await?;

-        Self {
-            stub: TextGenerationServiceClient::new(timeout_channel),
-        }
+        Ok(Self {
+            stub: TextGenerationServiceClient::new(channel),
+        })
     }

-    /// Returns a client connected to the given unix socket. Requests exceeding timeout will fail.
-    pub async fn connect_uds(path: String, timeout: Duration) -> Self {
+    /// Returns a client connected to the given unix socket
+    pub async fn connect_uds(path: String) -> Result<Self> {
         let channel = Channel::from_shared("http://[::]:50051".to_string())
             .unwrap()
             .connect_with_connector(tower::service_fn(move |_: Uri| {
                 tokio::net::UnixStream::connect(path.clone())
             }))
-            .await
-            .expect("Transport error");
-        let timeout_channel = Timeout::new(channel, timeout);
+            .await?;

-        Self {
-            stub: TextGenerationServiceClient::new(timeout_channel),
-        }
+        Ok(Self {
+            stub: TextGenerationServiceClient::new(channel),
+        })
     }

     #[instrument(skip(self))]
@@ -8,22 +8,26 @@ pub use client::Client;
 pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request};
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
-pub use tonic::transport::Uri;
+pub use tonic::transport;
 use tonic::Status;

 #[derive(Error, Debug, Clone)]
-#[error("Text generation client error: {msg:?}")]
-pub struct ClientError {
-    msg: String,
-    // source: Status,
+pub enum ClientError {
+    #[error("Could not connect to Text Generation server: {0:?}")]
+    Connection(String),
+    #[error("Server error: {0:?}")]
+    Generation(String),
 }

 impl From<Status> for ClientError {
     fn from(err: Status) -> Self {
-        Self {
-            msg: err.to_string(),
-            // source: err,
-        }
+        Self::Generation(err.to_string())
     }
 }
+
+impl From<transport::Error> for ClientError {
+    fn from(err: transport::Error) -> Self {
+        Self::Connection(err.to_string())
+    }
+}
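Note: with ClientError split into Connection and Generation variants and a From<transport::Error> impl in place, callers can propagate failures with `?` instead of panicking via expect. A minimal sketch (not part of the commit), using the crate-local Result alias already imported in client.rs above:

    use tonic::transport::Uri;

    // Hedged sketch: transport failures become ClientError::Connection and
    // gRPC Status errors become ClientError::Generation, both through `?`.
    async fn discover_shards(uri: Uri) -> Result<Vec<String>> {
        let mut client = Client::connect(uri).await?;
        let shard_uris = client.service_discovery().await?;
        Ok(shard_uris)
    }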
@@ -1,7 +1,6 @@
 use crate::Result;
 use crate::{Batch, Client, GeneratedText};
 use futures::future::join_all;
-use std::time::Duration;
 use tokio::sync::{broadcast, mpsc};
 use tonic::transport::Uri;
@@ -69,24 +68,22 @@ impl ShardedClient {
         Self { request_tx }
     }

-    async fn from_master_client(mut master_client: Client) -> Self {
+    async fn from_master_client(mut master_client: Client) -> Result<Self> {
         let uris = master_client.service_discovery().await.unwrap();
-        let futures = uris
-            .into_iter()
-            .map(|path| Client::connect_uds(path, Duration::from_secs(5)));
-        let clients = join_all(futures).await;
-        Self::new(clients)
+        let futures = uris.into_iter().map(|path| Client::connect_uds(path));
+        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
+        Ok(Self::new(clients?))
     }

-    /// Returns a client connected to the given url. Requests exceeding timeout will fail.
-    pub async fn connect(uri: Uri, timeout: Duration) -> Self {
-        let master_client = Client::connect(uri, timeout).await;
+    /// Returns a client connected to the given url
+    pub async fn connect(uri: Uri) -> Result<Self> {
+        let master_client = Client::connect(uri).await?;
         Self::from_master_client(master_client).await
     }

-    /// Returns a client connected to the given unix socket. Requests exceeding timeout will fail.
-    pub async fn connect_uds(path: String, timeout: Duration) -> Self {
-        let master_client = Client::connect_uds(path, timeout).await;
+    /// Returns a client connected to the given unix socket
+    pub async fn connect_uds(path: String) -> Result<Self> {
+        let master_client = Client::connect_uds(path).await?;
         Self::from_master_client(master_client).await
     }
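Note: from_master_client now surfaces per-shard connection failures instead of silently ignoring them: join_all yields a Vec of Results, and collecting that into Result<Vec<Client>> short-circuits on the first error. The same idiom in isolation, as a hedged sketch:

    use futures::future::join_all;

    // Sketch only: Vec<Result<Client, ClientError>> collects into
    // Result<Vec<Client>, ClientError>, stopping at the first Err.
    async fn connect_all(paths: Vec<String>) -> Result<Vec<Client>> {
        let futures = paths.into_iter().map(Client::connect_uds);
        join_all(futures).await.into_iter().collect()
    }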
@@ -1,13 +1,30 @@
 use crate::server::GenerateRequest;
-use crate::Db;
+use crate::{Db, Entry};
+use axum::http::StatusCode;
 use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use std::future::Future;
 use std::sync::Arc;
+use thiserror::Error;
 use tokio::sync::{oneshot, Notify};

 const MAX_LENGTH: usize = 128;

-pub struct InferError {}
+#[derive(Debug, Error)]
+pub enum InferError {
+    #[error("Request failed during generation: {0}")]
+    GenerationError(String),
+    #[error("Model is overloaded")]
+    Overloaded,
+}
+
+impl From<InferError> for (StatusCode, String) {
+    fn from(err: InferError) -> Self {
+        match err {
+            InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
+            InferError::Overloaded => (StatusCode::TOO_MANY_REQUESTS, err.to_string()),
+        }
+    }
+}

 #[derive(Clone)]
 pub(crate) struct Batcher {
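Note: the From<InferError> for (StatusCode, String) impl is what lets the axum handlers in server.rs below return Result<_, (StatusCode, String)> and apply `?` directly to infer(). A hedged sketch of the conversion that `?` performs:

    use axum::http::StatusCode;

    // Sketch only: the mapping applied when an InferError crosses into a handler's error type.
    // InferError::Overloaded          -> 429 TOO_MANY_REQUESTS
    // InferError::GenerationError(..) -> 500 INTERNAL_SERVER_ERROR
    fn to_http(err: InferError) -> (StatusCode, String) {
        err.into()
    }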
@@ -37,14 +54,18 @@ impl Batcher {
         request: GenerateRequest,
     ) -> Result<String, InferError> {
         if self.db.len() > MAX_LENGTH {
-            return Err(InferError {});
+            return Err(InferError::Overloaded);
         }
         let (request_tx, request_rx) = oneshot::channel();
-        self.db.append(input_length, request, request_tx);
+        self.db.append(Entry {
+            request,
+            response_tx: request_tx,
+            input_length,
+        });
         self.shared.batching_task.notify_waiters();
         match request_rx.await.unwrap() {
             Ok(output) => Ok(output),
-            Err(_) => Err(InferError {}),
+            Err(err) => Err(InferError::GenerationError(err.to_string())),
         }
     }
 }
@@ -108,7 +129,6 @@ async fn wrap_future(
             next_batch
         }
         Err(err) => {
-            println!("{:?}", err);
             send_error(err, request_ids, db);
             None
         }
@@ -117,14 +137,18 @@ async fn wrap_future(

 fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
     request_ids.into_iter().for_each(|id| {
-        let (_, response_tx) = db.remove(&id).unwrap();
-        response_tx.send(Err(error.clone())).unwrap_or(());
+        let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
+        // unwrap_or is valid here as we don't care if the receiver is gone.
+        entry.response_tx.send(Err(error.clone())).unwrap_or(());
     });
 }

 fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
     finished.into_iter().for_each(|output| {
-        let (_, response_tx) = db.remove(&output.request.unwrap().id).unwrap();
-        response_tx.send(Ok(output.output)).unwrap_or(());
+        let entry = db
+            .remove(&output.request.unwrap().id)
+            .expect("ID not found in db. This is a bug.");
+        // unwrap_or is valid here as we don't care if the receiver is gone.
+        entry.response_tx.send(Ok(output.output)).unwrap_or(());
     });
 }
@@ -1,11 +1,29 @@
 /// This code is massively inspired by Tokio mini-redis
-use crate::server::GenerateRequest;
+use crate::server::{GenerateParameters, GenerateRequest};
 use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
 use parking_lot::RwLock;
 use std::collections::BTreeMap;
 use std::sync::Arc;
 use tokio::sync::oneshot::Sender;

+#[derive(Debug)]
+pub(crate) struct Entry {
+    pub request: GenerateRequest,
+    pub response_tx: Sender<Result<String, ClientError>>,
+    pub input_length: usize,
+}
+
+impl From<GenerateParameters> for LogitsWarperParameters {
+    fn from(parameters: GenerateParameters) -> Self {
+        Self {
+            temperature: parameters.temperature,
+            top_k: parameters.top_k as u32,
+            top_p: parameters.top_p,
+            do_sample: parameters.do_sample,
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
 pub(crate) struct Db {
     pub shared: Arc<Shared>,
@@ -18,7 +36,7 @@ pub struct Shared {

 #[derive(Debug)]
 struct State {
-    entries: BTreeMap<u64, (Request, Sender<Result<String, ClientError>>)>,
+    entries: BTreeMap<u64, Entry>,

     /// Identifier to use for the next expiration. Each expiration is associated
     /// with a unique identifier. See above for why.
@@ -44,37 +62,16 @@ impl Db {
         Self { shared }
     }

-    pub(crate) fn append(
-        &self,
-        input_length: usize,
-        request: GenerateRequest,
-        sender: Sender<Result<String, ClientError>>,
-    ) {
+    pub(crate) fn append(&self, entry: Entry) {
         let mut state = self.shared.state.write();

         let id = state.next_id;
         state.next_id += 1;

-        let parameters = Some(LogitsWarperParameters {
-            temperature: request.parameters.temperature,
-            top_k: request.parameters.top_k,
-            top_p: request.parameters.top_p,
-            do_sample: request.parameters.do_sample,
-        });
-        let request = Request {
-            id,
-            inputs: request.inputs,
-            input_length: input_length as u32,
-            parameters,
-            max_new_tokens: request.parameters.max_new_tokens,
-        };
-        state.entries.insert(id, (request, sender));
+        state.entries.insert(id, entry);
     }

-    pub(crate) fn remove(
-        &self,
-        id: &u64,
-    ) -> Option<(Request, Sender<Result<String, ClientError>>)> {
+    pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
         let mut state = self.shared.state.write();
         state.entries.remove(id)
     }
@@ -91,7 +88,15 @@ impl Db {
             .entries
             .range(state.next_batch_start_id..)
             .take(max_size)
-            .map(|(_, (request, _))| request.clone())
+            .map(|(id, entry)| Request {
+                id: *id,
+                inputs: entry.request.inputs.clone(),
+                input_length: entry.input_length as u32,
+                parameters: Some(LogitsWarperParameters::from(
+                    entry.request.parameters.clone(),
+                )),
+                max_new_tokens: entry.request.parameters.max_new_tokens,
+            })
             .collect();

         if requests.is_empty() {
@@ -0,0 +1,8 @@
+mod batcher;
+mod db;
+pub mod server;
+mod validation;
+
+use batcher::Batcher;
+use db::{Db, Entry};
+use validation::Validation;
@@ -1,21 +1,8 @@
 use bloom_inference_client::ShardedClient;
 use std::net::SocketAddr;
-use std::time::Duration;
+use text_generation_router::server;
 use tokenizers::Tokenizer;

-mod server;
-mod validation;
-
-use validation::Validation;
-
-mod db;
-
-use db::Db;
-
-mod batcher;
-
-use batcher::Batcher;
-
 fn main() -> Result<(), std::io::Error> {
     let tokenizer = Tokenizer::from_pretrained("bigscience/bloom", None).unwrap();

@@ -26,11 +13,9 @@ fn main() -> Result<(), std::io::Error> {
         .block_on(async {
             tracing_subscriber::fmt::init();

-            let sharded_client = ShardedClient::connect_uds(
-                "/tmp/bloom-inference-0".to_string(),
-                Duration::from_secs(5),
-            )
-            .await;
+            let sharded_client = ShardedClient::connect_uds("/tmp/bloom-inference-0".to_string())
+                .await
+                .expect("Could not connect to server");
             sharded_client
                 .clear_cache()
                 .await
@@ -1,9 +1,9 @@
-use bloom_inference_client::ShardedClient;
 use crate::{Batcher, Validation};
 use axum::extract::Extension;
 use axum::http::StatusCode;
 use axum::routing::{get, post};
 use axum::{Json, Router};
+use bloom_inference_client::ShardedClient;
 use serde::Deserialize;
 use std::net::SocketAddr;
 use tokenizers::Tokenizer;
@@ -15,7 +15,7 @@ pub(crate) struct GenerateParameters {
     #[serde(default = "default_temperature")]
     pub temperature: f32,
     #[serde(default = "default_top_k")]
-    pub top_k: u32,
+    pub top_k: i32,
     #[serde(default = "default_top_p")]
     pub top_p: f32,
     #[serde(default = "default_do_sample")]
@@ -28,7 +28,7 @@ fn default_temperature() -> f32 {
     1.0
 }

-fn default_top_k() -> u32 {
+fn default_top_k() -> i32 {
     0
 }

@@ -62,8 +62,8 @@ pub(crate) struct GenerateRequest {
 }

 #[instrument(skip(state), fields(time, time_per_token))]
-async fn liveness(state: Extension<ServerState>) -> Result<(), StatusCode> {
-    let output = state
+async fn liveness(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
+    state
         .infer
         .infer(
             1,
@@ -78,37 +78,27 @@ async fn liveness(state: Extension<ServerState>) -> Result<(), StatusCode> {
                 },
             },
         )
-        .await;
-
-    match output {
-        Ok(_) => Ok(()),
-        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
-    }
+        .await?;
+    Ok(())
 }

 #[instrument(skip(state), fields(time, time_per_token))]
 async fn generate(
     state: Extension<ServerState>,
     req: Json<GenerateRequest>,
-) -> Result<Json<serde_json::Value>, StatusCode> {
+) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
     let start = Instant::now();

-    let (input_length, validated_request) = match state
+    let (input_length, validated_request) = state
         .validation
         .validate(GenerateRequest {
             inputs: req.inputs.clone(),
             parameters: req.parameters.clone(),
         })
-        .await
-    {
-        Ok(result) => result,
-        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
-    };
+        .await?;

-    let output = state.infer.infer(input_length, validated_request).await;
-
-    match output {
-        Ok(generated_text) => {
+    let generated_text = state.infer.infer(input_length, validated_request).await?;
+
     tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
     tracing::Span::current().record(
         "time_per_token",
@@ -119,9 +109,6 @@ async fn generate(
     Ok(Json(serde_json::json!({
         "generated_text": generated_text,
     })))
-        }
-        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
-    }
 }

 #[derive(Clone)]
@@ -1,9 +1,28 @@
 use crate::server::GenerateRequest;
+use axum::http::StatusCode;
+use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokio::sync::{mpsc, oneshot};

-#[derive(Debug)]
-pub struct ValidationError {}
+#[derive(Error, Debug)]
+pub enum ValidationError {
+    #[error("Temperature must be strictly positive")]
+    Temperature,
+    #[error("Top p must be <= 0.0 or > 1.0")]
+    TopP,
+    #[error("Top k must be strictly positive")]
+    TopK,
+    #[error("Max New Tokens must be < 512")]
+    MaxNewTokens,
+    #[error("Inputs must have less than 512 tokens. Given: {0}")]
+    InputLength(usize),
+}
+
+impl From<ValidationError> for (StatusCode, String) {
+    fn from(err: ValidationError) -> Self {
+        (StatusCode::BAD_REQUEST, err.to_string())
+    }
+}
+
 type ValidationRequest = (
     GenerateRequest,
@@ -39,15 +58,23 @@ impl Validation {
 async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<ValidationRequest>) {
     while let Some((request, response_tx)) = receiver.recv().await {
         if request.parameters.temperature < 0.0 {
-            response_tx.send(Err(ValidationError {})).unwrap_or(());
+            response_tx
+                .send(Err(ValidationError::Temperature))
+                .unwrap_or(());
             continue;
         }
         if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
-            response_tx.send(Err(ValidationError {})).unwrap_or(());
+            response_tx.send(Err(ValidationError::TopP)).unwrap_or(());
+            continue;
+        }
+        if request.parameters.top_k < 0 {
+            response_tx.send(Err(ValidationError::TopK)).unwrap_or(());
             continue;
         }
         if request.parameters.max_new_tokens > 512 {
-            response_tx.send(Err(ValidationError {})).unwrap_or(());
+            response_tx
+                .send(Err(ValidationError::MaxNewTokens))
+                .unwrap_or(());
             continue;
         }

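Note: top_k changed from u32 to i32 in server.rs above, presumably so a negative value reaches this validator and is reported as a 400 instead of failing JSON deserialization; db.rs casts it back to u32 when building LogitsWarperParameters. A hedged sketch of that check and cast (check_top_k is a hypothetical helper, not in the commit):

    // Sketch only: mirrors the top_k validation above and the `as u32` cast in db.rs.
    fn check_top_k(top_k: i32) -> Result<u32, ValidationError> {
        if top_k < 0 {
            return Err(ValidationError::TopK);
        }
        Ok(top_k as u32)
    }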
@@ -55,11 +82,12 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<Vali
         let input_length = inputs.len();

         if input_length > 512 {
-            response_tx.send(Err(ValidationError {})).unwrap_or(());
+            response_tx
+                .send(Err(ValidationError::InputLength(input_length)))
+                .unwrap_or(());
             continue;
         }

         response_tx.send(Ok((input_length, request))).unwrap_or(());
     }
-    println!("drop here");
 }
@@ -1,10 +1,12 @@
 #!/usr/bin/env bash

-server_cmd="python server/bloom_inference/main.py $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"
-$server_cmd &
+server_cmd="bloom-inference-server launcher $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"

+# Run in background
+$server_cmd 2>&1 > /dev/null &
+
+# Check if server is running by checking if the unix socket is created
 FILE=/tmp/bloom-inference-0

 while :
 do
     if test -S "$FILE"; then
|
||||||
|
|
||||||
sleep 1
|
sleep 1
|
||||||
|
|
||||||
exec "bloom-inference"
|
# Run in background
|
||||||
|
text-generation-router &
|
||||||
|
|
||||||
|
# Wait for any process to exit
|
||||||
|
wait -n
|
||||||
|
|
||||||
|
# Exit with status of process that exited first
|
||||||
|
exit $?
|
|
@@ -0,0 +1,42 @@
+import typer
+
+from pathlib import Path
+from torch.distributed.launcher import launch_agent, LaunchConfig
+from typing import Optional
+
+from bloom_inference import server
+
+app = typer.Typer()
+
+
+@app.command()
+def launcher(
+    model_name: str,
+    num_gpus: int = 1,
+    shard_directory: Optional[Path] = None,
+):
+    if num_gpus == 1:
+        serve(model_name, False, shard_directory)
+
+    else:
+        config = LaunchConfig(
+            min_nodes=1,
+            max_nodes=1,
+            nproc_per_node=num_gpus,
+            rdzv_backend="c10d",
+            max_restarts=0,
+        )
+        launch_agent(config, server.serve, [model_name, True, shard_directory])
+
+
+@app.command()
+def serve(
+    model_name: str,
+    sharded: bool = False,
+    shard_directory: Optional[Path] = None,
+):
+    server.serve(model_name, sharded, shard_directory)
+
+
+if __name__ == "__main__":
+    app()
@@ -1,30 +0,0 @@
-import typer
-
-from pathlib import Path
-from torch.distributed.launcher import launch_agent, LaunchConfig
-from typing import Optional
-
-from bloom_inference.server import serve
-
-
-def main(
-    model_name: str,
-    num_gpus: int = 1,
-    shard_directory: Optional[Path] = None,
-):
-    if num_gpus == 1:
-        serve(model_name, False, shard_directory)
-
-    else:
-        config = LaunchConfig(
-            min_nodes=1,
-            max_nodes=1,
-            nproc_per_node=num_gpus,
-            rdzv_backend="c10d",
-            max_restarts=0,
-        )
-        launch_agent(config, serve, [model_name, True, shard_directory])
-
-
-if __name__ == "__main__":
-    typer.run(main)
@@ -1,4 +1,6 @@
 import asyncio
+import os

 from grpc import aio
+
 from grpc_reflection.v1alpha import reflection
@@ -143,7 +145,3 @@ def serve(model_name, sharded, shard_directory):
         await server.wait_for_termination()

     asyncio.run(serve_inner(model_name, sharded, shard_directory))
-
-
-if __name__ == "__main__":
-    serve("bigscience/bloom-560m", True, Path("/tmp/models"))
@@ -2,6 +2,8 @@ import os
 import contextlib
 import torch
 import torch.distributed
+
+from datetime import timedelta
 from transformers.generation_logits_process import (
     LogitsProcessorList,
     TemperatureLogitsWarper,
|
||||||
backend=backend,
|
backend=backend,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
|
timeout=timedelta(seconds=60),
|
||||||
init_method="tcp://localhost:6000",
|
init_method="tcp://localhost:6000",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@@ -4,6 +4,9 @@ version = "0.1.0"
 description = "BLOOM Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]

+[tool.poetry.scripts]
+bloom-inference-server = 'bloom_inference.cli:app'
+
 [tool.poetry.dependencies]
 python = "^3.9"
 protobuf = "^4.21.7"