feat: Improve error handling

This commit is contained in:
Olivier Dehaene 2022-10-17 14:59:00 +02:00
parent 00e6ce44b1
commit 5e5d8766a2
20 changed files with 267 additions and 203 deletions

View File

@ -18,6 +18,7 @@ ENV LANG=C.UTF-8 \
MODEL_NAME=bigscience/bloom \
NUM_GPUS=8 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NCCL_ASYNC_ERROR_HANDLING=1 \
CUDA_HOME=/usr/local/cuda \
LD_LIBRARY_PATH="/opt/miniconda/envs/text-generation/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" \
CONDA_DEFAULT_ENV=text-generation \
@ -51,7 +52,7 @@ RUN cd server && \
/opt/miniconda/envs/text-generation/bin/pip install . --no-cache-dir
# Install router
COPY --from=builder /usr/local/cargo/bin/bloom-inference /usr/local/bin/bloom-inference
COPY --from=builder /usr/local/cargo/bin/text-generation-router /usr/local/bin/text-generation-router
COPY run.sh .
RUN chmod +x run.sh

View File

@ -48,5 +48,4 @@ python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-dire
- [ ] Add tests
- [ ] Add shutdown logic in router and server
- [ ] Improve multi-processing logic in server
- [ ] Improve error handling everywhere
- [ ] Improve past key layer indexing?

View File

@ -8,7 +8,7 @@ environment_variables:
MODEL_NAME: bigscience/bloom
NUM_GPUS: 8
environment:
image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.3
inference_config:
liveness_route:
port: 3000
@ -24,15 +24,15 @@ request_settings:
request_timeout_ms: 90000
max_concurrent_requests_per_instance: 256
liveness_probe:
initial_delay: 300
initial_delay: 600
timeout: 20
period: 60
period: 120
success_threshold: 1
failure_threshold: 60
failure_threshold: 3
readiness_probe:
initial_delay: 300
initial_delay: 600
timeout: 20
period: 60
period: 120
success_threshold: 1
failure_threshold: 60
failure_threshold: 3
instance_count: 1

33
router/Cargo.lock generated
View File

@ -149,22 +149,6 @@ dependencies = [
"generic-array",
]
[[package]]
name = "bloom-inference"
version = "0.1.0"
dependencies = [
"axum",
"bloom-inference-client",
"futures",
"parking_lot",
"serde",
"serde_json",
"tokenizers",
"tokio",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "bloom-inference-client"
version = "0.1.0"
@ -1669,6 +1653,23 @@ dependencies = [
"winapi",
]
[[package]]
name = "text-generation-router"
version = "0.1.0"
dependencies = [
"axum",
"bloom-inference-client",
"futures",
"parking_lot",
"serde",
"serde_json",
"thiserror",
"tokenizers",
"tokio",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "textwrap"
version = "0.11.0"

View File

@ -1,8 +1,15 @@
[package]
name = "bloom-inference"
name = "text-generation-router"
version = "0.1.0"
edition = "2021"
[lib]
path = "src/lib.rs"
[[bin]]
name = "text-generation-router"
path = "src/main.rs"
[dependencies]
axum = { version = "0.5.16", features = ["json", "serde_json"] }
bloom-inference-client = { path = "client" }
@ -10,6 +17,7 @@ futures = "0.3.24"
parking_lot = "0.12.1"
serde = "1.0.145"
serde_json = "1.0.85"
thiserror = "1.0.37"
tokenizers = "0.13.0"
tokio = { version = "1.21.1", features = ["rt-multi-thread", "parking_lot", "sync"] }
tracing = "0.1.36"

View File

@ -1,45 +1,37 @@
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*;
use crate::Result;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tower::timeout::Timeout;
use tracing::*;
/// BLOOM Inference gRPC client
#[derive(Clone)]
pub struct Client {
stub: TextGenerationServiceClient<Timeout<Channel>>,
stub: TextGenerationServiceClient<Channel>,
}
impl Client {
/// Returns a client connected to the given url. Requests exceeding timeout will fail.
pub async fn connect(uri: Uri, timeout: Duration) -> Self {
let channel = Channel::builder(uri)
.connect()
.await
.expect("Transport error");
let timeout_channel = Timeout::new(channel, timeout);
/// Returns a client connected to the given url
pub async fn connect(uri: Uri) -> Result<Self> {
let channel = Channel::builder(uri).connect().await?;
Self {
stub: TextGenerationServiceClient::new(timeout_channel),
}
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
}
/// Returns a client connected to the given unix socket. Requests exceeding timeout will fail.
pub async fn connect_uds(path: String, timeout: Duration) -> Self {
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let channel = Channel::from_shared("http://[::]:50051".to_string())
.unwrap()
.connect_with_connector(tower::service_fn(move |_: Uri| {
tokio::net::UnixStream::connect(path.clone())
}))
.await
.expect("Transport error");
let timeout_channel = Timeout::new(channel, timeout);
.await?;
Self {
stub: TextGenerationServiceClient::new(timeout_channel),
}
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
}
#[instrument(skip(self))]

View File

@ -8,23 +8,27 @@ pub use client::Client;
pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request};
pub use sharded_client::ShardedClient;
use thiserror::Error;
pub use tonic::transport::Uri;
pub use tonic::transport;
use tonic::Status;
#[derive(Error, Debug, Clone)]
#[error("Text generation client error: {msg:?}")]
pub struct ClientError {
msg: String,
// source: Status,
pub enum ClientError {
#[error("Could not connect to Text Generation server: {0:?}")]
Connection(String),
#[error("Server error: {0:?}")]
Generation(String),
}
impl From<Status> for ClientError {
fn from(err: Status) -> Self {
Self {
msg: err.to_string(),
// source: err,
Self::Generation(err.to_string())
}
}
impl From<transport::Error> for ClientError {
fn from(err: transport::Error) -> Self {
Self::Connection(err.to_string())
}
}
pub type Result<T> = std::result::Result<T, ClientError>;

View File

@ -1,7 +1,6 @@
use crate::Result;
use crate::{Batch, Client, GeneratedText};
use futures::future::join_all;
use std::time::Duration;
use tokio::sync::{broadcast, mpsc};
use tonic::transport::Uri;
@ -69,24 +68,22 @@ impl ShardedClient {
Self { request_tx }
}
async fn from_master_client(mut master_client: Client) -> Self {
async fn from_master_client(mut master_client: Client) -> Result<Self> {
let uris = master_client.service_discovery().await.unwrap();
let futures = uris
.into_iter()
.map(|path| Client::connect_uds(path, Duration::from_secs(5)));
let clients = join_all(futures).await;
Self::new(clients)
let futures = uris.into_iter().map(|path| Client::connect_uds(path));
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
Ok(Self::new(clients?))
}
/// Returns a client connected to the given url. Requests exceeding timeout will fail.
pub async fn connect(uri: Uri, timeout: Duration) -> Self {
let master_client = Client::connect(uri, timeout).await;
/// Returns a client connected to the given url
pub async fn connect(uri: Uri) -> Result<Self> {
let master_client = Client::connect(uri).await?;
Self::from_master_client(master_client).await
}
/// Returns a client connected to the given unix socket. Requests exceeding timeout will fail.
pub async fn connect_uds(path: String, timeout: Duration) -> Self {
let master_client = Client::connect_uds(path, timeout).await;
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let master_client = Client::connect_uds(path).await?;
Self::from_master_client(master_client).await
}

View File

@ -1,13 +1,30 @@
use crate::server::GenerateRequest;
use crate::Db;
use crate::{Db, Entry};
use axum::http::StatusCode;
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
use std::future::Future;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{oneshot, Notify};
const MAX_LENGTH: usize = 128;
pub struct InferError {}
#[derive(Debug, Error)]
pub enum InferError {
#[error("Request failed during generation: {0}")]
GenerationError(String),
#[error("Model is overloaded")]
Overloaded,
}
impl From<InferError> for (StatusCode, String) {
fn from(err: InferError) -> Self {
match err {
InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
InferError::Overloaded => (StatusCode::TOO_MANY_REQUESTS, err.to_string()),
}
}
}
#[derive(Clone)]
pub(crate) struct Batcher {
@ -37,14 +54,18 @@ impl Batcher {
request: GenerateRequest,
) -> Result<String, InferError> {
if self.db.len() > MAX_LENGTH {
return Err(InferError {});
return Err(InferError::Overloaded);
}
let (request_tx, request_rx) = oneshot::channel();
self.db.append(input_length, request, request_tx);
self.db.append(Entry {
request,
response_tx: request_tx,
input_length,
});
self.shared.batching_task.notify_waiters();
match request_rx.await.unwrap() {
Ok(output) => Ok(output),
Err(_) => Err(InferError {}),
Err(err) => Err(InferError::GenerationError(err.to_string())),
}
}
}
@ -108,7 +129,6 @@ async fn wrap_future(
next_batch
}
Err(err) => {
println!("{:?}", err);
send_error(err, request_ids, db);
None
}
@ -117,14 +137,18 @@ async fn wrap_future(
fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
request_ids.into_iter().for_each(|id| {
let (_, response_tx) = db.remove(&id).unwrap();
response_tx.send(Err(error.clone())).unwrap_or(());
let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
// unwrap_or is valid here as we don't care if the receiver is gone.
entry.response_tx.send(Err(error.clone())).unwrap_or(());
});
}
fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
finished.into_iter().for_each(|output| {
let (_, response_tx) = db.remove(&output.request.unwrap().id).unwrap();
response_tx.send(Ok(output.output)).unwrap_or(());
let entry = db
.remove(&output.request.unwrap().id)
.expect("ID not found in db. This is a bug.");
// unwrap_or is valid here as we don't care if the receiver is gone.
entry.response_tx.send(Ok(output.output)).unwrap_or(());
});
}

View File

@ -1,11 +1,29 @@
/// This code is massively inspired by Tokio mini-redis
use crate::server::GenerateRequest;
use crate::server::{GenerateParameters, GenerateRequest};
use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
use parking_lot::RwLock;
use std::collections::BTreeMap;
use std::sync::Arc;
use tokio::sync::oneshot::Sender;
#[derive(Debug)]
pub(crate) struct Entry {
pub request: GenerateRequest,
pub response_tx: Sender<Result<String, ClientError>>,
pub input_length: usize,
}
impl From<GenerateParameters> for LogitsWarperParameters {
fn from(parameters: GenerateParameters) -> Self {
Self {
temperature: parameters.temperature,
top_k: parameters.top_k as u32,
top_p: parameters.top_p,
do_sample: parameters.do_sample,
}
}
}
#[derive(Debug, Clone)]
pub(crate) struct Db {
pub shared: Arc<Shared>,
@ -18,7 +36,7 @@ pub struct Shared {
#[derive(Debug)]
struct State {
entries: BTreeMap<u64, (Request, Sender<Result<String, ClientError>>)>,
entries: BTreeMap<u64, Entry>,
/// Identifier to use for the next expiration. Each expiration is associated
/// with a unique identifier. See above for why.
@ -44,37 +62,16 @@ impl Db {
Self { shared }
}
pub(crate) fn append(
&self,
input_length: usize,
request: GenerateRequest,
sender: Sender<Result<String, ClientError>>,
) {
pub(crate) fn append(&self, entry: Entry) {
let mut state = self.shared.state.write();
let id = state.next_id;
state.next_id += 1;
let parameters = Some(LogitsWarperParameters {
temperature: request.parameters.temperature,
top_k: request.parameters.top_k,
top_p: request.parameters.top_p,
do_sample: request.parameters.do_sample,
});
let request = Request {
id,
inputs: request.inputs,
input_length: input_length as u32,
parameters,
max_new_tokens: request.parameters.max_new_tokens,
};
state.entries.insert(id, (request, sender));
state.entries.insert(id, entry);
}
pub(crate) fn remove(
&self,
id: &u64,
) -> Option<(Request, Sender<Result<String, ClientError>>)> {
pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
let mut state = self.shared.state.write();
state.entries.remove(id)
}
@ -91,7 +88,15 @@ impl Db {
.entries
.range(state.next_batch_start_id..)
.take(max_size)
.map(|(_, (request, _))| request.clone())
.map(|(id, entry)| Request {
id: *id,
inputs: entry.request.inputs.clone(),
input_length: entry.input_length as u32,
parameters: Some(LogitsWarperParameters::from(
entry.request.parameters.clone(),
)),
max_new_tokens: entry.request.parameters.max_new_tokens,
})
.collect();
if requests.is_empty() {

8
router/src/lib.rs Normal file
View File

@ -0,0 +1,8 @@
mod batcher;
mod db;
pub mod server;
mod validation;
use batcher::Batcher;
use db::{Db, Entry};
use validation::Validation;

View File

@ -1,21 +1,8 @@
use bloom_inference_client::ShardedClient;
use std::net::SocketAddr;
use std::time::Duration;
use text_generation_router::server;
use tokenizers::Tokenizer;
mod server;
mod validation;
use validation::Validation;
mod db;
use db::Db;
mod batcher;
use batcher::Batcher;
fn main() -> Result<(), std::io::Error> {
let tokenizer = Tokenizer::from_pretrained("bigscience/bloom", None).unwrap();
@ -26,11 +13,9 @@ fn main() -> Result<(), std::io::Error> {
.block_on(async {
tracing_subscriber::fmt::init();
let sharded_client = ShardedClient::connect_uds(
"/tmp/bloom-inference-0".to_string(),
Duration::from_secs(5),
)
.await;
let sharded_client = ShardedClient::connect_uds("/tmp/bloom-inference-0".to_string())
.await
.expect("Could not connect to server");
sharded_client
.clear_cache()
.await

View File

@ -1,9 +1,9 @@
use bloom_inference_client::ShardedClient;
use crate::{Batcher, Validation};
use axum::extract::Extension;
use axum::http::StatusCode;
use axum::routing::{get, post};
use axum::{Json, Router};
use bloom_inference_client::ShardedClient;
use serde::Deserialize;
use std::net::SocketAddr;
use tokenizers::Tokenizer;
@ -15,7 +15,7 @@ pub(crate) struct GenerateParameters {
#[serde(default = "default_temperature")]
pub temperature: f32,
#[serde(default = "default_top_k")]
pub top_k: u32,
pub top_k: i32,
#[serde(default = "default_top_p")]
pub top_p: f32,
#[serde(default = "default_do_sample")]
@ -28,7 +28,7 @@ fn default_temperature() -> f32 {
1.0
}
fn default_top_k() -> u32 {
fn default_top_k() -> i32 {
0
}
@ -62,8 +62,8 @@ pub(crate) struct GenerateRequest {
}
#[instrument(skip(state), fields(time, time_per_token))]
async fn liveness(state: Extension<ServerState>) -> Result<(), StatusCode> {
let output = state
async fn liveness(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
state
.infer
.infer(
1,
@ -78,37 +78,27 @@ async fn liveness(state: Extension<ServerState>) -> Result<(), StatusCode> {
},
},
)
.await;
match output {
Ok(_) => Ok(()),
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
}
.await?;
Ok(())
}
#[instrument(skip(state), fields(time, time_per_token))]
async fn generate(
state: Extension<ServerState>,
req: Json<GenerateRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
let start = Instant::now();
let (input_length, validated_request) = match state
let (input_length, validated_request) = state
.validation
.validate(GenerateRequest {
inputs: req.inputs.clone(),
parameters: req.parameters.clone(),
})
.await
{
Ok(result) => result,
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
};
.await?;
let output = state.infer.infer(input_length, validated_request).await;
let generated_text = state.infer.infer(input_length, validated_request).await?;
match output {
Ok(generated_text) => {
tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
tracing::Span::current().record(
"time_per_token",
@ -120,9 +110,6 @@ async fn generate(
"generated_text": generated_text,
})))
}
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
}
}
#[derive(Clone)]
struct ServerState {

View File

@ -1,9 +1,28 @@
use crate::server::GenerateRequest;
use axum::http::StatusCode;
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::{mpsc, oneshot};
#[derive(Debug)]
pub struct ValidationError {}
#[derive(Error, Debug)]
pub enum ValidationError {
#[error("Temperature must be strictly positive")]
Temperature,
#[error("Top p must be <= 0.0 or > 1.0")]
TopP,
#[error("Top k must be strictly positive")]
TopK,
#[error("Max New Tokens must be < 512")]
MaxNewTokens,
#[error("Inputs must have less than 512 tokens. Given: {0}")]
InputLength(usize),
}
impl From<ValidationError> for (StatusCode, String) {
fn from(err: ValidationError) -> Self {
(StatusCode::BAD_REQUEST, err.to_string())
}
}
type ValidationRequest = (
GenerateRequest,
@ -39,15 +58,23 @@ impl Validation {
async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<ValidationRequest>) {
while let Some((request, response_tx)) = receiver.recv().await {
if request.parameters.temperature < 0.0 {
response_tx.send(Err(ValidationError {})).unwrap_or(());
response_tx
.send(Err(ValidationError::Temperature))
.unwrap_or(());
continue;
}
if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
response_tx.send(Err(ValidationError {})).unwrap_or(());
response_tx.send(Err(ValidationError::TopP)).unwrap_or(());
continue;
}
if request.parameters.top_k < 0 {
response_tx.send(Err(ValidationError::TopK)).unwrap_or(());
continue;
}
if request.parameters.max_new_tokens > 512 {
response_tx.send(Err(ValidationError {})).unwrap_or(());
response_tx
.send(Err(ValidationError::MaxNewTokens))
.unwrap_or(());
continue;
}
@ -55,11 +82,12 @@ async fn validation_task(tokenizer: Tokenizer, mut receiver: mpsc::Receiver<Vali
let input_length = inputs.len();
if input_length > 512 {
response_tx.send(Err(ValidationError {})).unwrap_or(());
response_tx
.send(Err(ValidationError::InputLength(input_length)))
.unwrap_or(());
continue;
}
response_tx.send(Ok((input_length, request))).unwrap_or(());
}
println!("drop here");
}

17
run.sh Executable file → Normal file
View File

@ -1,10 +1,12 @@
#!/usr/bin/env bash
server_cmd="python server/bloom_inference/main.py $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"
$server_cmd &
server_cmd="bloom-inference-server launcher $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"
# Run in background
$server_cmd 2>&1 > /dev/null &
# Check if server is running by checking if the unix socket is created
FILE=/tmp/bloom-inference-0
while :
do
if test -S "$FILE"; then
@ -18,4 +20,11 @@ while :
sleep 1
exec "bloom-inference"
# Run in background
text-generation-router &
# Wait for any process to exit
wait -n
# Exit with status of process that exited first
exit $?

View File

@ -0,0 +1,42 @@
import typer
from pathlib import Path
from torch.distributed.launcher import launch_agent, LaunchConfig
from typing import Optional
from bloom_inference import server
app = typer.Typer()
@app.command()
def launcher(
model_name: str,
num_gpus: int = 1,
shard_directory: Optional[Path] = None,
):
if num_gpus == 1:
serve(model_name, False, shard_directory)
else:
config = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=num_gpus,
rdzv_backend="c10d",
max_restarts=0,
)
launch_agent(config, server.serve, [model_name, True, shard_directory])
@app.command()
def serve(
model_name: str,
sharded: bool = False,
shard_directory: Optional[Path] = None,
):
server.serve(model_name, sharded, shard_directory)
if __name__ == "__main__":
app()

View File

@ -1,30 +0,0 @@
import typer
from pathlib import Path
from torch.distributed.launcher import launch_agent, LaunchConfig
from typing import Optional
from bloom_inference.server import serve
def main(
model_name: str,
num_gpus: int = 1,
shard_directory: Optional[Path] = None,
):
if num_gpus == 1:
serve(model_name, False, shard_directory)
else:
config = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=num_gpus,
rdzv_backend="c10d",
max_restarts=0,
)
launch_agent(config, serve, [model_name, True, shard_directory])
if __name__ == "__main__":
typer.run(main)

View File

@ -1,4 +1,6 @@
import asyncio
import os
from grpc import aio
from grpc_reflection.v1alpha import reflection
@ -143,7 +145,3 @@ def serve(model_name, sharded, shard_directory):
await server.wait_for_termination()
asyncio.run(serve_inner(model_name, sharded, shard_directory))
if __name__ == "__main__":
serve("bigscience/bloom-560m", True, Path("/tmp/models"))

View File

@ -2,6 +2,8 @@ import os
import contextlib
import torch
import torch.distributed
from datetime import timedelta
from transformers.generation_logits_process import (
LogitsProcessorList,
TemperatureLogitsWarper,
@ -79,6 +81,7 @@ def initialize_torch_distributed():
backend=backend,
world_size=world_size,
rank=rank,
timeout=timedelta(seconds=60),
init_method="tcp://localhost:6000",
)

View File

@ -4,6 +4,9 @@ version = "0.1.0"
description = "BLOOM Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
[tool.poetry.scripts]
bloom-inference-server = 'bloom_inference.cli:app'
[tool.poetry.dependencies]
python = "^3.9"
protobuf = "^4.21.7"