This commit is contained in:
Olivier Dehaene 2022-10-11 18:14:39 +02:00
parent e86ecbac63
commit 39df4d9975
5 changed files with 121 additions and 148 deletions

202
router/Cargo.lock generated
View File

@ -81,6 +81,53 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "axum"
version = "0.5.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9e3356844c4d6a6d6467b8da2cffb4a2820be256f50a3a386c9d152bab31043"
dependencies = [
"async-trait",
"axum-core",
"bitflags",
"bytes",
"futures-util",
"http",
"http-body",
"hyper",
"itoa",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower",
"tower-http",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-core"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9f0c0a60006f2a293d82d571f635042a72edf927539b7685bd62d361963839b"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http",
"http-body",
"mime",
"tower-layer",
"tower-service",
]
[[package]]
name = "base64"
version = "0.13.0"
@ -106,10 +153,10 @@ dependencies = [
name = "bloom-inference"
version = "0.1.0"
dependencies = [
"axum",
"bloom-inference-client",
"futures",
"parking_lot",
"poem",
"serde",
"serde_json",
"tokenizers",
@ -661,31 +708,6 @@ version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]]
name = "headers"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3e372db8e5c0d213e0cd0b9be18be2aca3d44cf2fe30a9d46a65581cd454584"
dependencies = [
"base64",
"bitflags",
"bytes",
"headers-core",
"http",
"httpdate",
"mime",
"sha1",
]
[[package]]
name = "headers-core"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429"
dependencies = [
"http",
]
[[package]]
name = "heck"
version = "0.3.3"
@ -726,6 +748,12 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "http-range-header"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29"
[[package]]
name = "httparse"
version = "1.8.0"
@ -941,6 +969,12 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26a8d2502d5aa4d411ef494ba7470eb299f05725179ce3b5de77aa01a9ffdea"
[[package]]
name = "matchit"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
[[package]]
name = "memchr"
version = "2.5.0"
@ -1201,65 +1235,12 @@ version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae"
[[package]]
name = "poem"
version = "1.3.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2992ba72908e36200671c0f3a692992ced894b3b2bbe2b2dc6dfbffea6e2c85a"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"headers",
"http",
"hyper",
"mime",
"parking_lot",
"percent-encoding",
"pin-project-lite",
"poem-derive",
"regex",
"rfc7239",
"serde",
"serde_json",
"serde_urlencoded",
"smallvec",
"thiserror",
"tokio",
"tokio-stream",
"tokio-util 0.7.4",
"tracing",
]
[[package]]
name = "poem-derive"
version = "1.3.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f535d4331a22610b98ca48f98bae9bda0c654da89b9ae10a1830fa9edfd8f36"
dependencies = [
"proc-macro-crate",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "ppv-lite86"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
[[package]]
name = "proc-macro-crate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eda0fc3b0fb7c975631757e14d9049da17374063edb6ebbcbc54d880d4fe94e9"
dependencies = [
"once_cell",
"thiserror",
"toml",
]
[[package]]
name = "proc-macro2"
version = "1.0.46"
@ -1479,15 +1460,6 @@ dependencies = [
"winreg",
]
[[package]]
name = "rfc7239"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "087317b3cf7eb481f13bd9025d729324b7cd068d6f470e2d76d049e191f5ba47"
dependencies = [
"uncased",
]
[[package]]
name = "ryu"
version = "1.0.11"
@ -1576,17 +1548,6 @@ dependencies = [
"serde",
]
[[package]]
name = "sha1"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "sha2"
version = "0.10.6"
@ -1667,6 +1628,12 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "sync_wrapper"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8"
[[package]]
name = "tar"
version = "0.4.38"
@ -1890,15 +1857,6 @@ dependencies = [
"tracing",
]
[[package]]
name = "toml"
version = "0.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d82e1a7758622a465f8cee077614c73484dac5b836c02ff6a40d5d1010324d7"
dependencies = [
"serde",
]
[[package]]
name = "tonic"
version = "0.6.2"
@ -1962,6 +1920,25 @@ dependencies = [
"tracing",
]
[[package]]
name = "tower-http"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c530c8675c1dbf98facee631536fa116b5fb6382d7dd6dc1b118d970eafe3ba"
dependencies = [
"bitflags",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-range-header",
"pin-project-lite",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-layer"
version = "0.3.1"
@ -2065,15 +2042,6 @@ version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
[[package]]
name = "uncased"
version = "0.9.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09b01702b0fd0b3fadcf98e098780badda8742d4f4a7676615cad90e8ac73622"
dependencies = [
"version_check",
]
[[package]]
name = "unicode-bidi"
version = "0.3.8"

View File

@ -3,13 +3,11 @@ name = "bloom-inference"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
axum = { version = "0.5.16", features = ["json", "serde_json"] }
bloom-inference-client = { path = "client" }
futures = "0.3.24"
parking_lot = "0.12.1"
poem = "1.3.45"
serde = "1.0.145"
serde_json = "1.0.85"
tokenizers = "0.13.0"

View File

@ -4,12 +4,12 @@ version = "0.1.0"
edition = "2021"
[dependencies]
futures = "0.3.24"
futures = "^0.3"
#grpc-error-details = { path = "../../grpc-error-details" }
#grpc-metadata = { path = "../../grpc-metadata" }
prost = "^0.9"
thiserror = "1.0.37"
tokio = { version = "1.21.2", features = ["sync"] }
thiserror = "^1.0"
tokio = { version = "^1.21", features = ["sync"] }
tonic = "^0.6"
tower = "^0.4"
tracing = "^0.1"

View File

@ -1,5 +1,5 @@
use std::net::SocketAddr;
use bloom_inference_client::ShardedClient;
use poem::listener::TcpListener;
use std::time::Duration;
use tokenizers::Tokenizer;
@ -37,9 +37,9 @@ fn main() -> Result<(), std::io::Error> {
.expect("Unable to clear cache");
tracing::info!("Connected");
let addr = "127.0.0.1:3000".to_string();
let listener = TcpListener::bind(addr);
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
server::run(sharded_client, tokenizer, listener).await
server::run(sharded_client, tokenizer, addr).await;
Ok(())
})
}

View File

@ -1,9 +1,9 @@
use std::net::SocketAddr;
use axum::{Router, Json};
use axum::http::StatusCode;
use axum::extract::Extension;
use axum::routing::post;
use crate::{Batcher, ShardedClient, Validation};
use poem::http::StatusCode;
use poem::listener::TcpListener;
use poem::middleware::AddData;
use poem::web::{Data, Json};
use poem::{handler, post, EndpointExt, Route, Server};
use serde::Deserialize;
use tokenizers::Tokenizer;
use tokio::time::Instant;
@ -60,26 +60,24 @@ pub(crate) struct GenerateRequest {
pub parameters: GenerateParameters,
}
#[handler]
#[instrument(skip(validation, infer), fields(time, time_per_token))]
#[instrument(skip(state), fields(time, time_per_token))]
async fn generate(
validation: Data<&Validation>,
infer: Data<&Batcher>,
state: Extension<ServerState>,
req: Json<GenerateRequest>,
) -> poem::Result<Json<serde_json::Value>> {
) -> Result<Json<serde_json::Value>, StatusCode> {
let start = Instant::now();
let (input_length, validated_request) = match validation
let (input_length, validated_request) = match state.validation
.validate(GenerateRequest {
inputs: req.inputs.clone(),
parameters: req.parameters.clone(),
})
.await {
Ok(result) => result,
Err(_) => return Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR))
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR)
};
let output = infer.infer(input_length, validated_request).await;
let output = state.infer.infer(input_length, validated_request).await;
match output {
Ok(generated_text) => {
@ -94,15 +92,21 @@ async fn generate(
"generated_text": generated_text,
})))
}
Err(_) => Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR)),
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
}
}
/// State shared with the HTTP request handlers.
///
/// `run` builds one instance and attaches it to the router through an
/// `Extension` layer; `generate` receives it as `Extension<ServerState>`.
// NOTE(review): `Clone` is presumably derived because axum's `Extension`
// hands each request its own copy of the wrapped value — confirm against
// the axum 0.5 documentation.
#[derive(Clone)]
struct ServerState {
/// Request validator; `generate` calls `validate` on it before inference.
validation: Validation,
/// Batcher; `generate` calls `infer` on it with the validated request.
infer: Batcher,
}
pub async fn run(
client: ShardedClient,
tokenizer: Tokenizer,
listener: TcpListener<String>,
) -> Result<(), std::io::Error> {
addr: SocketAddr,
) {
client.clear_cache().await.expect("Unable to clear cache");
tracing::info!("Connected");
@ -110,10 +114,13 @@ pub async fn run(
let validation = Validation::new(tokenizer);
let app = Route::new()
.at("/generate", post(generate))
.with(AddData::new(validation))
.with(AddData::new(infer));
let shared_state = ServerState {
validation,
infer,
};
Server::new(listener).run(app).await
let app = Router::new().route("/generate", post(generate)).layer(Extension(shared_state));
axum::Server::bind(&addr)
.serve(app.into_make_service()).await.unwrap();
}