Use axum
This commit is contained in:
parent
e86ecbac63
commit
39df4d9975
|
@ -81,6 +81,53 @@ version = "1.1.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
||||
|
||||
[[package]]
|
||||
name = "axum"
|
||||
version = "0.5.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c9e3356844c4d6a6d6467b8da2cffb4a2820be256f50a3a386c9d152bab31043"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core",
|
||||
"bitflags",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"itoa",
|
||||
"matchit",
|
||||
"memchr",
|
||||
"mime",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tower",
|
||||
"tower-http",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-core"
|
||||
version = "0.2.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9f0c0a60006f2a293d82d571f635042a72edf927539b7685bd62d361963839b"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"mime",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.13.0"
|
||||
|
@ -106,10 +153,10 @@ dependencies = [
|
|||
name = "bloom-inference"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"bloom-inference-client",
|
||||
"futures",
|
||||
"parking_lot",
|
||||
"poem",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokenizers",
|
||||
|
@ -661,31 +708,6 @@ version = "0.12.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
||||
|
||||
[[package]]
|
||||
name = "headers"
|
||||
version = "0.3.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f3e372db8e5c0d213e0cd0b9be18be2aca3d44cf2fe30a9d46a65581cd454584"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"bitflags",
|
||||
"bytes",
|
||||
"headers-core",
|
||||
"http",
|
||||
"httpdate",
|
||||
"mime",
|
||||
"sha1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "headers-core"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429"
|
||||
dependencies = [
|
||||
"http",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.3.3"
|
||||
|
@ -726,6 +748,12 @@ dependencies = [
|
|||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-range-header"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29"
|
||||
|
||||
[[package]]
|
||||
name = "httparse"
|
||||
version = "1.8.0"
|
||||
|
@ -941,6 +969,12 @@ version = "0.1.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f26a8d2502d5aa4d411ef494ba7470eb299f05725179ce3b5de77aa01a9ffdea"
|
||||
|
||||
[[package]]
|
||||
name = "matchit"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.5.0"
|
||||
|
@ -1201,65 +1235,12 @@ version = "0.3.25"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae"
|
||||
|
||||
[[package]]
|
||||
name = "poem"
|
||||
version = "1.3.45"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2992ba72908e36200671c0f3a692992ced894b3b2bbe2b2dc6dfbffea6e2c85a"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"headers",
|
||||
"http",
|
||||
"hyper",
|
||||
"mime",
|
||||
"parking_lot",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"poem-derive",
|
||||
"regex",
|
||||
"rfc7239",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"smallvec",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tokio-util 0.7.4",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "poem-derive"
|
||||
version = "1.3.45"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9f535d4331a22610b98ca48f98bae9bda0c654da89b9ae10a1830fa9edfd8f36"
|
||||
dependencies = [
|
||||
"proc-macro-crate",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-crate"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eda0fc3b0fb7c975631757e14d9049da17374063edb6ebbcbc54d880d4fe94e9"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"thiserror",
|
||||
"toml",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.46"
|
||||
|
@ -1479,15 +1460,6 @@ dependencies = [
|
|||
"winreg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rfc7239"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "087317b3cf7eb481f13bd9025d729324b7cd068d6f470e2d76d049e191f5ba47"
|
||||
dependencies = [
|
||||
"uncased",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.11"
|
||||
|
@ -1576,17 +1548,6 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha1"
|
||||
version = "0.10.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha2"
|
||||
version = "0.10.6"
|
||||
|
@ -1667,6 +1628,12 @@ dependencies = [
|
|||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sync_wrapper"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8"
|
||||
|
||||
[[package]]
|
||||
name = "tar"
|
||||
version = "0.4.38"
|
||||
|
@ -1890,15 +1857,6 @@ dependencies = [
|
|||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "0.5.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8d82e1a7758622a465f8cee077614c73484dac5b836c02ff6a40d5d1010324d7"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tonic"
|
||||
version = "0.6.2"
|
||||
|
@ -1962,6 +1920,25 @@ dependencies = [
|
|||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-http"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3c530c8675c1dbf98facee631536fa116b5fb6382d7dd6dc1b118d970eafe3ba"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"http-range-header",
|
||||
"pin-project-lite",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-layer"
|
||||
version = "0.3.1"
|
||||
|
@ -2065,15 +2042,6 @@ version = "1.15.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
|
||||
|
||||
[[package]]
|
||||
name = "uncased"
|
||||
version = "0.9.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09b01702b0fd0b3fadcf98e098780badda8742d4f4a7676615cad90e8ac73622"
|
||||
dependencies = [
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-bidi"
|
||||
version = "0.3.8"
|
||||
|
|
|
@ -3,13 +3,11 @@ name = "bloom-inference"
|
|||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
axum = { version = "0.5.16", features = ["json", "serde_json"] }
|
||||
bloom-inference-client = { path = "client" }
|
||||
futures = "0.3.24"
|
||||
parking_lot = "0.12.1"
|
||||
poem = "1.3.45"
|
||||
serde = "1.0.145"
|
||||
serde_json = "1.0.85"
|
||||
tokenizers = "0.13.0"
|
||||
|
|
|
@ -4,12 +4,12 @@ version = "0.1.0"
|
|||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
futures = "0.3.24"
|
||||
futures = "^0.3"
|
||||
#grpc-error-details = { path = "../../grpc-error-details" }
|
||||
#grpc-metadata = { path = "../../grpc-metadata" }
|
||||
prost = "^0.9"
|
||||
thiserror = "1.0.37"
|
||||
tokio = { version = "1.21.2", features = ["sync"] }
|
||||
thiserror = "^1.0"
|
||||
tokio = { version = "^1.21", features = ["sync"] }
|
||||
tonic = "^0.6"
|
||||
tower = "^0.4"
|
||||
tracing = "^0.1"
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use std::net::SocketAddr;
|
||||
use bloom_inference_client::ShardedClient;
|
||||
use poem::listener::TcpListener;
|
||||
use std::time::Duration;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
|
@ -37,9 +37,9 @@ fn main() -> Result<(), std::io::Error> {
|
|||
.expect("Unable to clear cache");
|
||||
tracing::info!("Connected");
|
||||
|
||||
let addr = "127.0.0.1:3000".to_string();
|
||||
let listener = TcpListener::bind(addr);
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
|
||||
server::run(sharded_client, tokenizer, listener).await
|
||||
server::run(sharded_client, tokenizer, addr).await;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
use std::net::SocketAddr;
|
||||
use axum::{Router, Json};
|
||||
use axum::http::StatusCode;
|
||||
use axum::extract::Extension;
|
||||
use axum::routing::post;
|
||||
use crate::{Batcher, ShardedClient, Validation};
|
||||
use poem::http::StatusCode;
|
||||
use poem::listener::TcpListener;
|
||||
use poem::middleware::AddData;
|
||||
use poem::web::{Data, Json};
|
||||
use poem::{handler, post, EndpointExt, Route, Server};
|
||||
use serde::Deserialize;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::time::Instant;
|
||||
|
@ -60,26 +60,24 @@ pub(crate) struct GenerateRequest {
|
|||
pub parameters: GenerateParameters,
|
||||
}
|
||||
|
||||
#[handler]
|
||||
#[instrument(skip(validation, infer), fields(time, time_per_token))]
|
||||
#[instrument(skip(state), fields(time, time_per_token))]
|
||||
async fn generate(
|
||||
validation: Data<&Validation>,
|
||||
infer: Data<&Batcher>,
|
||||
state: Extension<ServerState>,
|
||||
req: Json<GenerateRequest>,
|
||||
) -> poem::Result<Json<serde_json::Value>> {
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let start = Instant::now();
|
||||
|
||||
let (input_length, validated_request) = match validation
|
||||
let (input_length, validated_request) = match state.validation
|
||||
.validate(GenerateRequest {
|
||||
inputs: req.inputs.clone(),
|
||||
parameters: req.parameters.clone(),
|
||||
})
|
||||
.await {
|
||||
Ok(result) => result,
|
||||
Err(_) => return Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR))
|
||||
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
};
|
||||
|
||||
let output = infer.infer(input_length, validated_request).await;
|
||||
let output = state.infer.infer(input_length, validated_request).await;
|
||||
|
||||
match output {
|
||||
Ok(generated_text) => {
|
||||
|
@ -94,15 +92,21 @@ async fn generate(
|
|||
"generated_text": generated_text,
|
||||
})))
|
||||
}
|
||||
Err(_) => Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR)),
|
||||
Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ServerState {
|
||||
validation: Validation,
|
||||
infer: Batcher,
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
client: ShardedClient,
|
||||
tokenizer: Tokenizer,
|
||||
listener: TcpListener<String>,
|
||||
) -> Result<(), std::io::Error> {
|
||||
addr: SocketAddr,
|
||||
) {
|
||||
client.clear_cache().await.expect("Unable to clear cache");
|
||||
tracing::info!("Connected");
|
||||
|
||||
|
@ -110,10 +114,13 @@ pub async fn run(
|
|||
|
||||
let validation = Validation::new(tokenizer);
|
||||
|
||||
let app = Route::new()
|
||||
.at("/generate", post(generate))
|
||||
.with(AddData::new(validation))
|
||||
.with(AddData::new(infer));
|
||||
let shared_state = ServerState {
|
||||
validation,
|
||||
infer,
|
||||
};
|
||||
|
||||
Server::new(listener).run(app).await
|
||||
let app = Router::new().route("/generate", post(generate)).layer(Extension(shared_state));
|
||||
|
||||
axum::Server::bind(&addr)
|
||||
.serve(app.into_make_service()).await.unwrap();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue