diff --git a/router/Cargo.lock b/router/Cargo.lock index 79761345..1f00df14 100644 --- a/router/Cargo.lock +++ b/router/Cargo.lock @@ -81,6 +81,53 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "axum" +version = "0.5.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9e3356844c4d6a6d6467b8da2cffb4a2820be256f50a3a386c9d152bab31043" +dependencies = [ + "async-trait", + "axum-core", + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "hyper", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-http", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9f0c0a60006f2a293d82d571f635042a72edf927539b7685bd62d361963839b" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "mime", + "tower-layer", + "tower-service", +] + [[package]] name = "base64" version = "0.13.0" @@ -106,10 +153,10 @@ dependencies = [ name = "bloom-inference" version = "0.1.0" dependencies = [ + "axum", "bloom-inference-client", "futures", "parking_lot", - "poem", "serde", "serde_json", "tokenizers", @@ -661,31 +708,6 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" -[[package]] -name = "headers" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3e372db8e5c0d213e0cd0b9be18be2aca3d44cf2fe30a9d46a65581cd454584" -dependencies = [ - "base64", - "bitflags", - "bytes", - "headers-core", - "http", - "httpdate", - "mime", - "sha1", -] - -[[package]] -name = "headers-core" 
-version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429" -dependencies = [ - "http", -] - [[package]] name = "heck" version = "0.3.3" @@ -726,6 +748,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-range-header" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29" + [[package]] name = "httparse" version = "1.8.0" @@ -941,6 +969,12 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26a8d2502d5aa4d411ef494ba7470eb299f05725179ce3b5de77aa01a9ffdea" +[[package]] +name = "matchit" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb" + [[package]] name = "memchr" version = "2.5.0" @@ -1201,65 +1235,12 @@ version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae" -[[package]] -name = "poem" -version = "1.3.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2992ba72908e36200671c0f3a692992ced894b3b2bbe2b2dc6dfbffea6e2c85a" -dependencies = [ - "async-trait", - "bytes", - "futures-util", - "headers", - "http", - "hyper", - "mime", - "parking_lot", - "percent-encoding", - "pin-project-lite", - "poem-derive", - "regex", - "rfc7239", - "serde", - "serde_json", - "serde_urlencoded", - "smallvec", - "thiserror", - "tokio", - "tokio-stream", - "tokio-util 0.7.4", - "tracing", -] - -[[package]] -name = "poem-derive" -version = "1.3.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f535d4331a22610b98ca48f98bae9bda0c654da89b9ae10a1830fa9edfd8f36" -dependencies = [ - "proc-macro-crate", - "proc-macro2", - "quote", - "syn", -] - 
[[package]] name = "ppv-lite86" version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" -[[package]] -name = "proc-macro-crate" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eda0fc3b0fb7c975631757e14d9049da17374063edb6ebbcbc54d880d4fe94e9" -dependencies = [ - "once_cell", - "thiserror", - "toml", -] - [[package]] name = "proc-macro2" version = "1.0.46" @@ -1479,15 +1460,6 @@ dependencies = [ "winreg", ] -[[package]] -name = "rfc7239" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "087317b3cf7eb481f13bd9025d729324b7cd068d6f470e2d76d049e191f5ba47" -dependencies = [ - "uncased", -] - [[package]] name = "ryu" version = "1.0.11" @@ -1576,17 +1548,6 @@ dependencies = [ "serde", ] -[[package]] -name = "sha1" -version = "0.10.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - [[package]] name = "sha2" version = "0.10.6" @@ -1667,6 +1628,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8" + [[package]] name = "tar" version = "0.4.38" @@ -1890,15 +1857,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "toml" -version = "0.5.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82e1a7758622a465f8cee077614c73484dac5b836c02ff6a40d5d1010324d7" -dependencies = [ - "serde", -] - [[package]] name = "tonic" version = "0.6.2" @@ -1962,6 +1920,25 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-http" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "3c530c8675c1dbf98facee631536fa116b5fb6382d7dd6dc1b118d970eafe3ba" +dependencies = [ + "bitflags", + "bytes", + "futures-core", + "futures-util", + "http", + "http-body", + "http-range-header", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" version = "0.3.1" @@ -2065,15 +2042,6 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" -[[package]] -name = "uncased" -version = "0.9.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09b01702b0fd0b3fadcf98e098780badda8742d4f4a7676615cad90e8ac73622" -dependencies = [ - "version_check", -] - [[package]] name = "unicode-bidi" version = "0.3.8" diff --git a/router/Cargo.toml b/router/Cargo.toml index 901a14f2..c5e5bb89 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -3,13 +3,11 @@ name = "bloom-inference" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] +axum = { version = "0.5.16", features = ["json", "serde_json"] } bloom-inference-client = { path = "client" } futures = "0.3.24" parking_lot = "0.12.1" -poem = "1.3.45" serde = "1.0.145" serde_json = "1.0.85" tokenizers = "0.13.0" diff --git a/router/client/Cargo.toml b/router/client/Cargo.toml index e4f2fa0f..7760c8cb 100644 --- a/router/client/Cargo.toml +++ b/router/client/Cargo.toml @@ -4,12 +4,12 @@ version = "0.1.0" edition = "2021" [dependencies] -futures = "0.3.24" +futures = "^0.3" #grpc-error-details = { path = "../../grpc-error-details" } #grpc-metadata = { path = "../../grpc-metadata" } prost = "^0.9" -thiserror = "1.0.37" -tokio = { version = "1.21.2", features = ["sync"] } +thiserror = "^1.0" +tokio = { version = "^1.21", features = ["sync"] } tonic = "^0.6" tower = "^0.4" tracing = "^0.1" diff --git a/router/src/main.rs b/router/src/main.rs 
index fe82d059..803753b1 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -1,5 +1,5 @@ +use std::net::SocketAddr; use bloom_inference_client::ShardedClient; -use poem::listener::TcpListener; use std::time::Duration; use tokenizers::Tokenizer; @@ -37,9 +37,9 @@ fn main() -> Result<(), std::io::Error> { .expect("Unable to clear cache"); tracing::info!("Connected"); - let addr = "127.0.0.1:3000".to_string(); - let listener = TcpListener::bind(addr); + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - server::run(sharded_client, tokenizer, listener).await + server::run(sharded_client, tokenizer, addr).await; + Ok(()) }) } diff --git a/router/src/server.rs b/router/src/server.rs index 64b37ff8..61c57069 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,9 +1,9 @@ +use std::net::SocketAddr; +use axum::{Router, Json}; +use axum::http::StatusCode; +use axum::extract::Extension; +use axum::routing::post; use crate::{Batcher, ShardedClient, Validation}; -use poem::http::StatusCode; -use poem::listener::TcpListener; -use poem::middleware::AddData; -use poem::web::{Data, Json}; -use poem::{handler, post, EndpointExt, Route, Server}; use serde::Deserialize; use tokenizers::Tokenizer; use tokio::time::Instant; @@ -60,26 +60,24 @@ pub(crate) struct GenerateRequest { pub parameters: GenerateParameters, } -#[handler] -#[instrument(skip(validation, infer), fields(time, time_per_token))] +#[instrument(skip(state), fields(time, time_per_token))] async fn generate( - validation: Data<&Validation>, - infer: Data<&Batcher>, + state: Extension<ServerState>, req: Json<GenerateRequest>, -) -> poem::Result<Json<serde_json::Value>> { +) -> Result<Json<serde_json::Value>, StatusCode> { let start = Instant::now(); - let (input_length, validated_request) = match validation + let (input_length, validated_request) = match state.validation .validate(GenerateRequest { inputs: req.inputs.clone(), parameters: req.parameters.clone(), }) .await { Ok(result) => result, - Err(_) => return 
Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR)) + Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR) }; - let output = infer.infer(input_length, validated_request).await; + let output = state.infer.infer(input_length, validated_request).await; match output { Ok(generated_text) => { @@ -94,15 +92,21 @@ async fn generate( "generated_text": generated_text, }))) } - Err(_) => Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR)), + Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), } } +#[derive(Clone)] +struct ServerState { + validation: Validation, + infer: Batcher, +} + pub async fn run( client: ShardedClient, tokenizer: Tokenizer, - listener: TcpListener, -) -> Result<(), std::io::Error> { + addr: SocketAddr, +) { client.clear_cache().await.expect("Unable to clear cache"); tracing::info!("Connected"); @@ -110,10 +114,13 @@ pub async fn run( let validation = Validation::new(tokenizer); - let app = Route::new() - .at("/generate", post(generate)) - .with(AddData::new(validation)) - .with(AddData::new(infer)); + let shared_state = ServerState { + validation, + infer, + }; - Server::new(listener).run(app).await + let app = Router::new().route("/generate", post(generate)).layer(Extension(shared_state)); + + axum::Server::bind(&addr) + .serve(app.into_make_service()).await.unwrap(); }