diff --git a/router/src/logging.rs b/router/src/logging.rs index 5a98ef57..44aa27c5 100644 --- a/router/src/logging.rs +++ b/router/src/logging.rs @@ -1,9 +1,15 @@ +use axum::body::Body; +use axum::http::{HeaderMap, Request}; +use axum::middleware::Next; +use axum::response::Response; +use opentelemetry::propagation::Extractor; use opentelemetry::sdk::propagation::TraceContextPropagator; use opentelemetry::sdk::trace; use opentelemetry::sdk::trace::Sampler; use opentelemetry::sdk::Resource; use opentelemetry::{global, KeyValue}; use opentelemetry_otlp::WithExportConfig; +use tracing_opentelemetry::OpenTelemetrySpanExt; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer}; @@ -79,3 +85,30 @@ pub fn init_logging(otlp_endpoint: Option, otlp_service_name: String, js .with(layers) .init(); } + +struct HeaderExtractor<'a>(&'a HeaderMap); + +impl<'a> Extractor for HeaderExtractor<'a> { + fn get(&self, key: &str) -> Option<&str> { + let value = self.0.get(key).and_then(|v| v.to_str().ok()); + value + } + + fn keys(&self) -> Vec<&str> { + let keys: Vec<&str> = self.0.keys().map(|k| k.as_str()).collect(); + keys + } +} + +pub async fn trace_context_middleware(request: Request, next: Next) -> Response { + let parent_ctx = global::get_text_map_propagator(|prop| { + let headers = request.headers(); + let extractor = HeaderExtractor(headers); + prop.extract(&extractor) + }); + + let span = tracing::Span::current(); + span.set_parent(parent_ctx); + + next.run(request).await +} diff --git a/router/src/server.rs b/router/src/server.rs index 8ec7a871..fbbe819d 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -8,7 +8,7 @@ use crate::kserve::{ kserve_model_metadata, kserve_model_metadata_ready, }; use crate::validation::ValidationError; -use crate::ChatTokenizeResponse; +use crate::{logging, ChatTokenizeResponse}; use crate::{ usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, @@ -2305,6 +2305,7 @@ async fn start( .layer(Extension(infer)) .layer(Extension(compute_type)) .layer(Extension(prom_handle.clone())) + .layer(axum::middleware::from_fn(logging::trace_context_middleware)) .layer(OtelAxumLayer::default()) .layer(cors_layer);