diff --git a/Cargo.lock b/Cargo.lock index 7fdf301a..c3a0a49a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3163,6 +3163,8 @@ dependencies = [ "pin-project-lite", "tower-layer", "tower-service", + "tracing", + "uuid", ] [[package]] @@ -3456,6 +3458,15 @@ dependencies = [ "zip", ] +[[package]] +name = "uuid" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" +dependencies = [ + "getrandom", +] + [[package]] name = "valuable" version = "0.1.0" diff --git a/router/Cargo.toml b/router/Cargo.toml index f6f16dae..687d3421 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -35,7 +35,7 @@ thiserror = "1.0.48" tokenizers = { version = "0.14.0", features = ["http"] } tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio-stream = "0.1.14" -tower-http = { version = "0.4.4", features = ["cors"] } +tower-http = { version = "0.4.4", features = ["cors", "request-id", "trace"] } tracing = "0.1.37" tracing-opentelemetry = "0.21.0" tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } diff --git a/router/src/server.rs b/router/src/server.rs index 52ed03df..16c5a62f 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -9,7 +9,7 @@ use crate::{ PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Validation, }; use axum::extract::Extension; -use axum::http::{HeaderMap, Method, StatusCode}; +use axum::http::{HeaderMap, Method, Request, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; @@ -27,6 +27,9 @@ use tokenizers::Tokenizer; use tokio::signal; use tokio::time::Instant; use tower_http::cors::{AllowOrigin, CorsLayer}; +use tower_http::request_id::{ + MakeRequestUuid, PropagateRequestIdLayer, RequestId, SetRequestIdLayer, +}; use tracing::{info_span, instrument, Instrument}; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; @@ -967,7 +970,35 @@ pub async fn run( .layer(Extension(compute_type)) .layer(Extension(prom_handle.clone())) .layer(OtelAxumLayer::default()) - .layer(cors_layer); + .layer(cors_layer) + .layer( + ::tower_http::trace::TraceLayer::new_for_http() + .make_span_with(RequestSpan::new()) + .on_request(|req: &::axum::http::Request<_>, _span: &::tracing::Span| { + tracing::info!("Request: {} {}", req.method(), req.uri(),); + }) + .on_response( + |res: &::axum::http::Response<_>, + latency: ::std::time::Duration, + _span: &::tracing::Span| { + tracing::info!( + took = latency.as_secs_f32(), + status_code = res.status().as_u16(), + "Response: {}", + res.status().as_u16(), + ); + }, + ) + .on_failure( + |error: ::tower_http::classify::ServerErrorsFailureClass, + latency: ::std::time::Duration, + _span: &::tracing::Span| { + ::tracing::warn!(took = latency.as_secs_f32(), "Failure: {error:?}"); + }, + ), + ) + .layer(PropagateRequestIdLayer::x_request_id()) + .layer(SetRequestIdLayer::x_request_id(MakeRequestUuid)); if ngrok { #[cfg(feature = "ngrok")] @@ -1103,3 +1134,29 @@ impl From for Event { .unwrap() } } + +#[derive(Clone)] +pub struct RequestSpan {} + +impl RequestSpan { + pub fn new() -> Self { + Self {} + } +} + +impl tower_http::trace::MakeSpan for RequestSpan { + fn make_span(&mut self, req: &Request) -> tracing::Span { + // SAFETY: Added by request ID middleware + let request_id = req + .extensions() + .get::() + .map(|s| String::from_utf8_lossy(s.header_value().as_bytes())) + .unwrap_or_default(); + + tracing::info_span!("request", + %request_id, + method = %req.method(), + uri = %req.uri(), + ) + } +}