feat: Improve error handling
commit 5e5d8766a2
parent 00e6ce44b1
@@ -18,6 +18,7 @@ ENV LANG=C.UTF-8 \
MODEL_NAME=bigscience/bloom \
NUM_GPUS=8 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+NCCL_ASYNC_ERROR_HANDLING=1 \
CUDA_HOME=/usr/local/cuda \
LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
CONDA_DEFAULT_ENV=text-generation \
@@ -51,7 +52,7 @@ RUN cd server && \
/opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir

# Install router
-COPY --from=builder /usr/local/cargo/bin/bloom-inference /usr/local/bin/bloom-inference
+COPY --from=builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router

COPY run.sh .
RUN chmod +x run.sh
@@ -48,5 +48,4 @@ python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-dire
- [ ] Add tests
- [ ] Add shutdown logic in router and server
- [ ] Improve multi-processing logic in server
-- [ ] Improve error handling everywhere
- [ ] Improve past key layer indexing?
@@ -8,7 +8,7 @@ environment_variables:
MODEL_NAME: bigscience/bloom
NUM_GPUS: 8
environment:
-image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
+image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.3
inference_config:
liveness_route:
port: 3000
@@ -24,15 +24,15 @@ request_settings:
request_timeout_ms: 90000
max_concurrent_requests_per_instance: 256
liveness_probe:
-initial_delay: 300
+initial_delay: 600
timeout: 20
-period: 60
+period: 120
success_threshold: 1
-failure_threshold: 60
+failure_threshold: 3
readiness_probe:
-initial_delay: 300
+initial_delay: 600
timeout: 20
-period: 60
+period: 120
success_threshold: 1
-failure_threshold: 60
+failure_threshold: 3
instance_count: 1
@@ -149,22 +149,6 @@ dependencies = [
"generic-array",
]

-[[package]]
-name = "bloom-inference"
-version = "0.1.0"
-dependencies = [
-"axum",
-"bloom-inference-client",
-"futures",
-"parking_lot",
-"serde",
-"serde_json",
-"tokenizers",
-"tokio",
-"tracing",
-"tracing-subscriber",
-]
-
[[package]]
name = "bloom-inference-client"
version = "0.1.0"
@@ -1669,6 +1653,23 @@ dependencies = [
"winapi",
]

+[[package]]
+name = "text-generation-router"
+version = "0.1.0"
+dependencies = [
+"axum",
+"bloom-inference-client",
+"futures",
+"parking_lot",
+"serde",
+"serde_json",
+"thiserror",
+"tokenizers",
+"tokio",
+"tracing",
+"tracing-subscriber",
+]
+
[[package]]
name = "textwrap"
version = "0.11.0"
@@ -1,8 +1,15 @@
[package]
-name = "bloom-inference"
+name = "text-generation-router"
version = "0.1.0"
edition = "2021"

+[lib]
+path = "src/lib.rs"
+
+[[bin]]
+name = "text-generation-router"
+path = "src/main.rs"
+
[dependencies]
axum = { version = "0.5.16", features = ["json", "serde_json"] }
bloom-inference-client = { path = "client" }
@@ -10,6 +17,7 @@ futures = "0.3.24"
parking_lot = "0.12.1"
serde = "1.0.145"
serde_json = "1.0.85"
+thiserror = "1.0.37"
tokenizers = "0.13.0"
tokio = { version = "1.21.1", features = ["rt-multi-thread", "parking_lot", "sync"] }
tracing = "0.1.36"
@@ -1,45 +1,37 @@
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*;
use crate::Result;
-use std::time::Duration;
use tonic::transport::{Channel, Uri};
-use tower::timeout::Timeout;
use tracing::*;

/// BLOOM Inference gRPC client
#[derive(Clone)]
pub struct Client {
-stub: TextGenerationServiceClient<Timeout<Channel>>,
+stub: TextGenerationServiceClient<Channel>,
}

impl Client {
-/// Returns a client connected to the given url. Requests exceeding timeout will fail.
-pub async fn connect(uri: Uri, timeout: Duration) -> Self {
-let channel = Channel::builder(uri)
-.connect()
-.await
-.expect("Transport error");
-let timeout_channel = Timeout::new(channel, timeout);
+/// Returns a client connected to the given url
+pub async fn connect(uri: Uri) -> Result<Self> {
+let channel = Channel::builder(uri).connect().await?;

-Self {
-stub: TextGenerationServiceClient::new(timeout_channel),
-}
+Ok(Self {
+stub: TextGenerationServiceClient::new(channel),
+})
}

-/// Returns a client connected to the given unix socket. Requests exceeding timeout will fail.
-pub async fn connect_uds(path: String, timeout: Duration) -> Self {
+/// Returns a client connected to the given unix socket
+pub async fn connect_uds(path: String) -> Result<Self> {
let channel = Channel::from_shared("http://[::]:50051".to_string())
.unwrap()
.connect_with_connector(tower::service_fn(move |_: Uri| {
tokio::net::UnixStream::connect(path.clone())
}))
-.await
-.expect("Transport error");
-let timeout_channel = Timeout::new(channel, timeout);
+.await?;

-Self {
-stub: TextGenerationServiceClient::new(timeout_channel),
-}
+Ok(Self {
+stub: TextGenerationServiceClient::new(channel),
+})
}

#[instrument(skip(self))]
@@ -8,22 +8,26 @@ pub use client::Client;
pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request};
pub use sharded_client::ShardedClient;
use thiserror::Error;
-pub use tonic::transport::Uri;
+pub use tonic::transport;
use tonic::Status;

#[derive(Error, Debug, Clone)]
-#[error("Text generation client error: {msg:?}")]
-pub struct ClientError {
-msg: String,
-// source: Status,
+pub enum ClientError {
+#[error("Could not connect to Text Generation server: {0:?}")]
+Connection(String),
+#[error("Server error: {0:?}")]
+Generation(String),
}

impl From<Status> for ClientError {
fn from(err: Status) -> Self {
-Self {
-msg: err.to_string(),
-// source: err,
-}
+Self::Generation(err.to_string())
}
}

+impl From<transport::Error> for ClientError {
+fn from(err: transport::Error) -> Self {
+Self::Connection(err.to_string())
+}
+}
+
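The client crate's error type changes from a single-message struct into an enum that separates connection failures from generation failures, with `From` impls for tonic's `transport::Error` and `Status` so callers can just use `?`. A minimal, self-contained sketch of that pattern (stand-in error types instead of tonic's; `thiserror` assumed as a dependency):

```rust
use thiserror::Error;

#[derive(Error, Debug, Clone)]
pub enum ClientError {
    #[error("Could not connect to Text Generation server: {0:?}")]
    Connection(String),
    #[error("Server error: {0:?}")]
    Generation(String),
}

// Stand-ins for tonic::transport::Error and tonic::Status, only for this sketch.
pub struct TransportError(String);
pub struct Status(String);

impl From<TransportError> for ClientError {
    fn from(err: TransportError) -> Self {
        Self::Connection(err.0)
    }
}

impl From<Status> for ClientError {
    fn from(err: Status) -> Self {
        Self::Generation(err.0)
    }
}

pub type Result<T> = std::result::Result<T, ClientError>;

fn connect() -> std::result::Result<(), TransportError> {
    Err(TransportError("connection refused".to_string()))
}

fn client_connect() -> Result<()> {
    // `?` applies the From impl, so transport failures surface as ClientError::Connection.
    connect()?;
    Ok(())
}

fn main() {
    println!("{:?}", client_connect());
}
```

With connection and generation failures split into variants, callers can react differently to each: the router aborts at startup with `expect("Could not connect to server")`, while generation errors are forwarded to the waiting request instead of crashing the process.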
@@ -1,7 +1,6 @@
use crate::Result;
use crate::{Batch, Client, GeneratedText};
use futures::future::join_all;
-use std::time::Duration;
use tokio::sync::{broadcast, mpsc};
use tonic::transport::Uri;

@@ -69,24 +68,22 @@ impl ShardedClient {
Self { request_tx }
}

-async fn from_master_client(mut master_client: Client) -> Self {
+async fn from_master_client(mut master_client: Client) -> Result<Self> {
let uris = master_client.service_discovery().await.unwrap();
-let futures = uris
-.into_iter()
-.map(|path| Client::connect_uds(path, Duration::from_secs(5)));
-let clients = join_all(futures).await;
-Self::new(clients)
+let futures = uris.into_iter().map(|path| Client::connect_uds(path));
+let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
+Ok(Self::new(clients?))
}

-/// Returns a client connected to the given url. Requests exceeding timeout will fail.
-pub async fn connect(uri: Uri, timeout: Duration) -> Self {
-let master_client = Client::connect(uri, timeout).await;
+/// Returns a client connected to the given url
+pub async fn connect(uri: Uri) -> Result<Self> {
+let master_client = Client::connect(uri).await?;
Self::from_master_client(master_client).await
}

-/// Returns a client connected to the given unix socket. Requests exceeding timeout will fail.
-pub async fn connect_uds(path: String, timeout: Duration) -> Self {
-let master_client = Client::connect_uds(path, timeout).await;
+/// Returns a client connected to the given unix socket
+pub async fn connect_uds(path: String) -> Result<Self> {
+let master_client = Client::connect_uds(path).await?;
Self::from_master_client(master_client).await
}

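`from_master_client` now bubbles up the first failed shard connection instead of panicking: `join_all` yields a `Vec` of `Result`s, and collecting it into `Result<Vec<Client>>` lets `?` abort on the first error. A small sketch of just that collect step, with stand-in `Client`/`ClientError` types (assumes the `futures` crate and `tokio` with the rt and macros features):

```rust
use futures::future::join_all;

#[derive(Debug)]
struct Client(u32);

#[derive(Debug)]
struct ClientError(String);

type Result<T> = std::result::Result<T, ClientError>;

async fn connect_uds(shard: u32) -> Result<Client> {
    if shard == 2 {
        Err(ClientError(format!("shard {shard} unreachable")))
    } else {
        Ok(Client(shard))
    }
}

#[tokio::main]
async fn main() {
    let futures = (0..4).map(connect_uds);
    // Vec<Result<Client>> -> Result<Vec<Client>>: the first Err wins.
    let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
    println!("{:?}", clients); // Err(ClientError("shard 2 unreachable"))
}
```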
@@ -1,13 +1,30 @@
use crate::server::GenerateRequest;
-use crate::Db;
+use crate::{Db, Entry};
+use axum::http::StatusCode;
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
use std::future::Future;
use std::sync::Arc;
+use thiserror::Error;
use tokio::sync::{oneshot, Notify};

const MAX_LENGTH: usize = 128;

-pub struct InferError {}
+#[derive(Debug, Error)]
+pub enum InferError {
+#[error("Request failed during generation: {0}")]
+GenerationError(String),
+#[error("Model is overloaded")]
+Overloaded,
+}
+
+impl From<InferError> for (StatusCode, String) {
+fn from(err: InferError) -> Self {
+match err {
+InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
+InferError::Overloaded => (StatusCode::TOO_MANY_REQUESTS, err.to_string()),
+}
+}
+}

#[derive(Clone)]
pub(crate) struct Batcher {
@@ -37,14 +54,18 @@ impl Batcher {
request: GenerateRequest,
) -> Result<String, InferError> {
if self.db.len() > MAX_LENGTH {
-return Err(InferError {});
+return Err(InferError::Overloaded);
}
let (request_tx, request_rx) = oneshot::channel();
-self.db.append(input_length, request, request_tx);
+self.db.append(Entry {
+request,
+response_tx: request_tx,
+input_length,
+});
self.shared.batching_task.notify_waiters();
match request_rx.await.unwrap() {
Ok(output) => Ok(output),
-Err(_) => Err(InferError {}),
+Err(err) => Err(InferError::GenerationError(err.to_string())),
}
}
}
@@ -108,7 +129,6 @@ async fn wrap_future(
next_batch
}
Err(err) => {
-println!("{:?}", err);
send_error(err, request_ids, db);
None
}
@@ -117,14 +137,18 @@ async fn wrap_future(

fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
request_ids.into_iter().for_each(|id| {
-let (_, response_tx) = db.remove(&id).unwrap();
-response_tx.send(Err(error.clone())).unwrap_or(());
+let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
+// unwrap_or is valid here as we don't care if the receiver is gone.
+entry.response_tx.send(Err(error.clone())).unwrap_or(());
});
}

fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
finished.into_iter().for_each(|output| {
-let (_, response_tx) = db.remove(&output.request.unwrap().id).unwrap();
-response_tx.send(Ok(output.output)).unwrap_or(());
+let entry = db
+.remove(&output.request.unwrap().id)
+.expect("ID not found in db. This is a bug.")
+// unwrap_or is valid here as we don't care if the receiver is gone.
+entry.response_tx.send(Ok(output.output)).unwrap_or(());
});
}
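`InferError` goes from an empty struct to an enum whose variants map onto HTTP statuses through `From<InferError> for (StatusCode, String)`; since axum treats a `(StatusCode, String)` error as a response, handlers can propagate failures with `?`. A sketch of the mapping outside of axum (only the `http` crate and `thiserror` are assumed):

```rust
use http::StatusCode;
use thiserror::Error;

#[derive(Debug, Error)]
pub enum InferError {
    #[error("Request failed during generation: {0}")]
    GenerationError(String),
    #[error("Model is overloaded")]
    Overloaded,
}

impl From<InferError> for (StatusCode, String) {
    fn from(err: InferError) -> Self {
        match err {
            InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
            InferError::Overloaded => (StatusCode::TOO_MANY_REQUESTS, err.to_string()),
        }
    }
}

fn infer(overloaded: bool) -> Result<String, InferError> {
    if overloaded {
        Err(InferError::Overloaded)
    } else {
        Ok("generated text".to_string())
    }
}

// Shaped like an axum handler: `?` converts InferError into (StatusCode, String).
fn generate(overloaded: bool) -> Result<String, (StatusCode, String)> {
    let text = infer(overloaded)?;
    Ok(text)
}

fn main() {
    // Prints an Err carrying status 429 and "Model is overloaded".
    println!("{:?}", generate(true));
}
```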
@@ -1,11 +1,29 @@
/// This code is massively inspired by Tokio mini-redis
-use crate::server::GenerateRequest;
+use crate::server::{GenerateParameters, GenerateRequest};
use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
use parking_lot::RwLock;
use std::collections::BTreeMap;
use std::sync::Arc;
use tokio::sync::oneshot::Sender;

+#[derive(Debug)]
+pub(crate) struct Entry {
+pub request: GenerateRequest,
+pub response_tx: Sender<Result<String, ClientError>>,
+pub input_length: usize,
+}
+
+impl From<GenerateParameters> for LogitsWarperParameters {
+fn from(parameters: GenerateParameters) -> Self {
+Self {
+temperature: parameters.temperature,
+top_k: parameters.top_k as u32,
+top_p: parameters.top_p,
+do_sample: parameters.do_sample,
+}
+}
+}
+
#[derive(Debug, Clone)]
pub(crate) struct Db {
pub shared: Arc<Shared>,
@@ -18,7 +36,7 @@ pub struct Shared {

#[derive(Debug)]
struct State {
-entries: BTreeMap<u64, (Request, Sender<Result<String, ClientError>>)>,
+entries: BTreeMap<u64, Entry>,

/// Identifier to use for the next expiration. Each expiration is associated
/// with a unique identifier. See above for why.
@@ -44,37 +62,16 @@ impl Db {
Self { shared }
}

-pub(crate) fn append(
-&self,
-input_length: usize,
-request: GenerateRequest,
-sender: Sender<Result<String, ClientError>>,
-) {
+pub(crate) fn append(&self, entry: Entry) {
let mut state = self.shared.state.write();

let id = state.next_id;
state.next_id += 1;

-let parameters = Some(LogitsWarperParameters {
-temperature: request.parameters.temperature,
-top_k: request.parameters.top_k,
-top_p: request.parameters.top_p,
-do_sample: request.parameters.do_sample,
-});
-let request = Request {
-id,
-inputs: request.inputs,
-input_length: input_length as u32,
-parameters,
-max_new_tokens: request.parameters.max_new_tokens,
-};
-state.entries.insert(id, (request, sender));
+state.entries.insert(id, entry);
}

-pub(crate) fn remove(
-&self,
-id: &u64,
-) -> Option<(Request, Sender<Result<String, ClientError>>)> {
+pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
let mut state = self.shared.state.write();
state.entries.remove(id)
}
@@ -91,7 +88,15 @@ impl Db {
.entries
.range(state.next_batch_start_id..)
.take(max_size)
-.map(|(_, (request, _))| request.clone())
+.map(|(id, entry)| Request {
+id: *id,
+inputs: entry.request.inputs.clone(),
+input_length: entry.input_length as u32,
+parameters: Some(LogitsWarperParameters::from(
+entry.request.parameters.clone(),
+)),
+max_new_tokens: entry.request.parameters.max_new_tokens,
+})
.collect();

if requests.is_empty() {
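The Db now stores whole `Entry` values and only builds the gRPC `Request` when a batch is cut, taking a contiguous id range out of the ordered map. A toy sketch of that range-plus-take query over a `BTreeMap` (std only; the names mirror the diff but the types are simplified):

```rust
use std::collections::BTreeMap;

fn main() {
    // Pending requests keyed by a monotonically increasing id, as in Db::entries.
    let mut entries: BTreeMap<u64, &str> = BTreeMap::new();
    for (id, inputs) in [(0u64, "a"), (1, "b"), (2, "c"), (3, "d")] {
        entries.insert(id, inputs);
    }

    let next_batch_start_id = 1u64;
    let max_size = 2;

    // Take at most max_size requests whose id is >= next_batch_start_id.
    let batch: Vec<(u64, &str)> = entries
        .range(next_batch_start_id..)
        .take(max_size)
        .map(|(id, inputs)| (*id, *inputs))
        .collect();

    println!("{:?}", batch); // [(1, "b"), (2, "c")]
}
```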
@@ -0,0 +1,8 @@
+mod batcher;
+mod db;
+pub mod server;
+mod validation;
+
+use batcher::Batcher;
+use db::{Db, Entry};
+use validation::Validation;
@@ -1,21 +1,8 @@
use bloom_inference_client::ShardedClient;
use std::net::SocketAddr;
-use std::time::Duration;
+use text_generation_router::server;
use tokenizers::Tokenizer;

-mod server;
-mod validation;
-
-use validation::Validation;
-
-mod db;
-
-use db::Db;
-
-mod batcher;
-
-use batcher::Batcher;
-
fn main() -> Result<(), std::io::Error> {
let tokenizer = Tokenizer::from_pretrained("bigscience/bloom", None).unwrap();

@@ -26,11 +13,9 @@ fn main() -> Result<(), std::io::Error> {
.block_on(async {
tracing_subscriber::fmt::init();

-let sharded_client = ShardedClient::connect_uds(
-"/tmp/bloom-inference-0".to_string(),
-Duration::from_secs(5),
-)
-.await;
+let sharded_client = ShardedClient::connect_uds("/tmp/bloom-inference-0".to_string())
+.await
+.expect("Could not connect to server");
sharded_client
.clear_cache()
.await
@@ -1,9 +1,9 @@
-use bloom_inference_client::ShardedClient;
use crate::{Batcher, Validation};
use axum::extract::Extension;
use axum::http::StatusCode;
use axum::routing::{get, post};
use axum::{Json, Router};
+use bloom_inference_client::ShardedClient;
use serde::Deserialize;
use std::net::SocketAddr;
use tokenizers::Tokenizer;
@@ -15,7 +15,7 @@ pub(crate) struct GenerateParameters {
#[serde(default = "default_temperature")]
pub temperature: f32,
#[serde(default = "default_top_k")]
-pub top_k: u32,
+pub top_k: i32,
#[serde(default = "default_top_p")]
pub top_p: f32,
#[serde(default = "default_do_sample")]
@@ -28,7 +28,7 @@ fn default_temperature() -> f32 {
1.0
}

-fn default_top_k() -> u32 {
+fn default_top_k() -> i32 {
0
}

@@ -62,8 +62,8 @@ pub(crate) struct GenerateRequest {
}

#[instrument(skip(state), fields(time, time_per_token))]
-async fn liveness(state: Extension<ServerState>) -> Result<(), StatusCode> {
-let output = state
+async fn liveness(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
+state
.infer
.infer(
1,
@@ -78,50 +78,37 @@ async fn liveness(state: Extension<ServerState>) -> Result<(), StatusCode> {
},
},
)
-.await;
-
-match output {
-Ok(_) => Ok(()),
-Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
-}
+.await?;
+Ok(())
}

#[instrument(skip(state), fields(time, time_per_token))]
async fn generate(
state: Extension<ServerState>,
req: Json<GenerateRequest>,
-) -> Result<Json<serde_json::Value>, StatusCode> {
+) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
let start = Instant::now();

-let (input_length, validated_request) = match state
+let (input_length, validated_request) = state
.validation
.validate(GenerateRequest {
inputs: req.inputs.clone(),
parameters: req.parameters.clone(),
})
-.await
-{
-Ok(result) => result,
-Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
-};
+.await?;

-let output = state.infer.infer(input_length, validated_request).await;
+let generated_text = state.infer.infer(input_length, validated_request).await?;

-match output {
-Ok(generated_text) => {
-tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
-tracing::Span::current().record(
-"time_per_token",
-format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
-);
-tracing::info!("response: {}", generated_text);
+tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
+tracing::Span::current().record(
+"time_per_token",
+format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
+);
+tracing::info!("response: {}", generated_text);

-Ok(Json(serde_json::json!({
-"generated_text": generated_text,
-})))
-}
-Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
-}
+Ok(Json(serde_json::json!({
+"generated_text": generated_text,
+})))
}

#[derive(Clone)]
@@ -1,9 +1,28 @@
use crate::server::GenerateRequest;
+use axum::http::StatusCode;
+use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot};

-#[derive(Debug)]
-pub struct ValidationError {}
+#[derive(Error, Debug)]
+pub enum ValidationError {
+#[error("Temperature must be strictly positive")]
+Temperature,
+#[error("Top p must be <= 0.0 or > 1.0")]
+TopP,
+#[error("Top k must be strictly positive")]
+TopK,
+#[error("Max New Tokens must be < 512")]
+MaxNewTokens,
+#[error("Inputs must have less than 512 tokens. Given: {0}")]
+InputLength(usize),
+}
+
+impl From<ValidationError> for (StatusCode, String) {
+fn from(err: ValidationError) -> Self {
+(StatusCode::BAD_REQUEST, err.to_string())
+}
+}

type ValidationRequest = (
GenerateRequest,
@@ -39,15 +58,23 @@ impl Validation {
async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<ValidationRequest>) {
while let Some((request, response_tx)) = receiver.recv().await {
if request.parameters.temperature < 0.0 {
-response_tx.send(Err(ValidationError {})).unwrap_or(());
+response_tx
+.send(Err(ValidationError::Temperature))
+.unwrap_or(());
continue;
}
if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
-response_tx.send(Err(ValidationError {})).unwrap_or(());
+response_tx.send(Err(ValidationError::TopP)).unwrap_or(());
continue;
}
+if request.parameters.top_k < 0 {
+response_tx.send(Err(ValidationError::TopK)).unwrap_or(());
+continue;
+}
if request.parameters.max_new_tokens > 512 {
-response_tx.send(Err(ValidationError {})).unwrap_or(());
+response_tx
+.send(Err(ValidationError::MaxNewTokens))
+.unwrap_or(());
continue;
}

@@ -55,11 +82,12 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<Vali
let input_length = inputs.len();

if input_length > 512 {
-response_tx.send(Err(ValidationError {})).unwrap_or(());
+response_tx
+.send(Err(ValidationError::InputLength(input_length)))
+.unwrap_or(());
continue;
}

response_tx.send(Ok((input_length, request))).unwrap_or(());
}
-println!("drop here");
}
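Both the validation task and the batching task reply over oneshot channels and deliberately ignore a failed send: the requester may already have dropped its receiver (for example a disconnected HTTP client), hence the repeated `.unwrap_or(())`. A minimal sketch of that pattern (assumes tokio with the sync, rt, and macros features):

```rust
use tokio::sync::oneshot;

#[tokio::main]
async fn main() {
    let (response_tx, response_rx) = oneshot::channel::<Result<String, String>>();

    // Simulate a caller that gave up and dropped its receiving end.
    drop(response_rx);

    // send returns Err(value) when the receiver is gone; unwrap_or(()) discards
    // the value instead of panicking, which is all the worker can do here.
    response_tx
        .send(Ok("generated text".to_string()))
        .unwrap_or(());
}
```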
@@ -1,10 +1,12 @@
#!/usr/bin/env bash

-server_cmd="python server/bloom_inference/main.py $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"
-$server_cmd &
+server_cmd="bloom-inference-server launcher $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"
+
+# Run in background
+$server_cmd 2>&1 > /dev/null &

# Check if server is running by checking if the unix socket is created
FILE=/tmp/bloom-inference-0

while :
do
if test -S "$FILE"; then
@@ -18,4 +20,11 @@ while :

sleep 1

-exec "bloom-inference"
+# Run in background
+text-generation-router &
+
+# Wait for any process to exit
+wait -n
+
+# Exit with status of process that exited first
+exit $?
@@ -0,0 +1,42 @@
+import typer
+
+from pathlib import Path
+from torch.distributed.launcher import launch_agent, LaunchConfig
+from typing import Optional
+
+from bloom_inference import server
+
+app = typer.Typer()
+
+
+@app.command()
+def launcher(
+model_name: str,
+num_gpus: int = 1,
+shard_directory: Optional[Path] = None,
+):
+if num_gpus == 1:
+serve(model_name, False, shard_directory)
+
+else:
+config = LaunchConfig(
+min_nodes=1,
+max_nodes=1,
+nproc_per_node=num_gpus,
+rdzv_backend="c10d",
+max_restarts=0,
+)
+launch_agent(config, server.serve, [model_name, True, shard_directory])
+
+
+@app.command()
+def serve(
+model_name: str,
+sharded: bool = False,
+shard_directory: Optional[Path] = None,
+):
+server.serve(model_name, sharded, shard_directory)
+
+
+if __name__ == "__main__":
+app()
@@ -1,30 +0,0 @@
-import typer
-
-from pathlib import Path
-from torch.distributed.launcher import launch_agent, LaunchConfig
-from typing import Optional
-
-from bloom_inference.server import serve
-
-
-def main(
-model_name: str,
-num_gpus: int = 1,
-shard_directory: Optional[Path] = None,
-):
-if num_gpus == 1:
-serve(model_name, False, shard_directory)
-
-else:
-config = LaunchConfig(
-min_nodes=1,
-max_nodes=1,
-nproc_per_node=num_gpus,
-rdzv_backend="c10d",
-max_restarts=0,
-)
-launch_agent(config, serve, [model_name, True, shard_directory])
-
-
-if __name__ == "__main__":
-typer.run(main)
@@ -1,4 +1,6 @@
import asyncio
+import os

from grpc import aio
+
from grpc_reflection.v1alpha import reflection
@@ -143,7 +145,3 @@ def serve(model_name, sharded, shard_directory):
await server.wait_for_termination()

asyncio.run(serve_inner(model_name, sharded, shard_directory))
-
-
-if __name__ == "__main__":
-serve("bigscience/bloom-560m", True, Path("/tmp/models"))
@@ -2,6 +2,8 @@ import os
import contextlib
import torch
+import torch.distributed

+from datetime import timedelta
from transformers.generation_logits_process import (
LogitsProcessorList,
TemperatureLogitsWarper,
@@ -79,6 +81,7 @@ def initialize_torch_distributed():
backend=backend,
world_size=world_size,
rank=rank,
+timeout=timedelta(seconds=60),
init_method="tcp://localhost:6000",
)

@@ -4,6 +4,9 @@ version = "0.1.0"
description = "BLOOM Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

+[tool.poetry.scripts]
+bloom-inference-server = 'bloom_inference.cli:app'
+
[tool.poetry.dependencies]
python = "^3.9"
protobuf = "^4.21.7"