feature: get trace ID from request headers

This commit is contained in:
kozistr 2024-10-15 15:14:20 +09:00
parent 58848cb471
commit 64b0337574
3 changed files with 92 additions and 4 deletions

View File

@ -1,13 +1,68 @@
use axum::{extract::Request, middleware::Next, response::Response};
use opentelemetry::sdk::propagation::TraceContextPropagator; use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace; use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler; use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource; use opentelemetry::sdk::Resource;
use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId};
use opentelemetry::Context;
use opentelemetry::{global, KeyValue}; use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig; use opentelemetry_otlp::WithExportConfig;
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer}; use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
/// Fields extracted from a W3C Trace Context `traceparent` HTTP header
/// (`{version}-{trace_id}-{parent_id}-{trace_flags}`), as produced by
/// `parse_traceparent`.
struct TraceParent {
    // Parsed and validated (0xff rejected) but not otherwise used yet.
    #[allow(dead_code)]
    version: u8,
    trace_id: TraceId,
    parent_id: SpanId,
    trace_flags: TraceFlags,
}
/// Parses a W3C Trace Context `traceparent` header value of the form
/// `{version}-{trace_id}-{parent_id}-{trace_flags}`, where the fields are
/// fixed-width hex strings of 2, 32, 16 and 2 characters respectively.
///
/// Returns `None` for any malformed or spec-invalid value, so a bad header is
/// silently ignored rather than failing the request.
fn parse_traceparent(header_value: &str) -> Option<TraceParent> {
    let parts: Vec<&str> = header_value.split('-').collect();
    if parts.len() != 4 {
        return None;
    }

    // The spec mandates fixed field widths; without this check, inputs such as
    // "0-abc-def-1" would be accepted by the radix parsers below.
    if parts[0].len() != 2 || parts[1].len() != 32 || parts[2].len() != 16 || parts[3].len() != 2 {
        return None;
    }

    let version = u8::from_str_radix(parts[0], 16).ok()?;
    // Version 0xff is explicitly forbidden by the spec.
    if version == 0xff {
        return None;
    }

    let trace_id = TraceId::from_hex(parts[1]).ok()?;
    let parent_id = SpanId::from_hex(parts[2]).ok()?;

    // An all-zero trace-id or parent-id is defined as invalid by the spec;
    // propagating one would create an unusable remote span context.
    if parts[1].bytes().all(|b| b == b'0') || parts[2].bytes().all(|b| b == b'0') {
        return None;
    }

    let trace_flags = u8::from_str_radix(parts[3], 16).ok()?;

    Some(TraceParent {
        version,
        trace_id,
        parent_id,
        trace_flags: TraceFlags::new(trace_flags),
    })
}
/// Axum middleware that reads the `traceparent` request header and stashes the
/// resulting remote `opentelemetry::Context` (or `None` when the header is
/// absent or malformed) into the request extensions as
/// `Option<opentelemetry::Context>`, where handlers can extract it.
pub async fn trace_context_middleware(mut request: Request, next: Next) -> Response {
    let traceparent = request
        .headers()
        .get("traceparent")
        .and_then(|value| value.to_str().ok())
        .and_then(parse_traceparent);

    // Always insert the Option so Extension<Option<Context>> extraction in
    // handlers succeeds even when no header was sent.
    let context = match traceparent {
        Some(parsed) => {
            let span_context = SpanContext::new(
                parsed.trace_id,
                parsed.parent_id,
                parsed.trace_flags,
                // Remote flag: this context originated from an upstream caller.
                true,
                Default::default(),
            );
            Some(Context::new().with_remote_span_context(span_context))
        }
        None => None,
    };
    request.extensions_mut().insert(context);

    next.run(request).await
}
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT: /// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// - otlp_endpoint is an optional URL to an Open Telemetry collector /// - otlp_endpoint is an optional URL to an Open Telemetry collector
/// - otlp_service_name service name to appear in APM /// - otlp_service_name service name to appear in APM

View File

@ -7,6 +7,7 @@ use crate::kserve::{
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer, kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
kserve_model_metadata, kserve_model_metadata_ready, kserve_model_metadata, kserve_model_metadata_ready,
}; };
use crate::logging::trace_context_middleware;
use crate::validation::ValidationError; use crate::validation::ValidationError;
use crate::vertex::vertex_compatibility; use crate::vertex::vertex_compatibility;
use crate::ChatTokenizeResponse; use crate::ChatTokenizeResponse;
@ -57,6 +58,7 @@ use tokio::sync::oneshot;
use tokio::time::Instant; use tokio::time::Instant;
use tower_http::cors::{AllowOrigin, CorsLayer}; use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument}; use tracing::{info_span, instrument, Instrument};
use tracing_opentelemetry::OpenTelemetrySpanExt;
use utoipa::OpenApi; use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi; use utoipa_swagger_ui::SwaggerUi;
@ -87,6 +89,7 @@ async fn compat_generate(
Extension(default_return_full_text): Extension<bool>, Extension(default_return_full_text): Extension<bool>,
infer: Extension<Infer>, infer: Extension<Infer>,
compute_type: Extension<ComputeType>, compute_type: Extension<ComputeType>,
context: Extension<Option<opentelemetry::Context>>,
Json(mut req): Json<CompatGenerateRequest>, Json(mut req): Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
// default return_full_text given the pipeline_tag // default return_full_text given the pipeline_tag
@ -96,11 +99,14 @@ async fn compat_generate(
// switch on stream // switch on stream
if req.stream { if req.stream {
Ok(generate_stream(infer, compute_type, Json(req.into())) Ok(
generate_stream(infer, compute_type, context, Json(req.into()))
.await .await
.into_response()) .into_response(),
)
} else { } else {
let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?; let (headers, Json(generation)) =
generate(infer, compute_type, context, Json(req.into())).await?;
// wrap generation inside a Vec to match api-inference // wrap generation inside a Vec to match api-inference
Ok((headers, Json(vec![generation])).into_response()) Ok((headers, Json(vec![generation])).into_response())
} }
@ -251,9 +257,14 @@ seed,
async fn generate( async fn generate(
infer: Extension<Infer>, infer: Extension<Infer>,
Extension(ComputeType(compute_type)): Extension<ComputeType>, Extension(ComputeType(compute_type)): Extension<ComputeType>,
Extension(context): Extension<Option<opentelemetry::Context>>,
Json(req): Json<GenerateRequest>, Json(req): Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> { ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current(); let span = tracing::Span::current();
if let Some(context) = context {
span.set_parent(context);
}
generate_internal(infer, ComputeType(compute_type), Json(req), span).await generate_internal(infer, ComputeType(compute_type), Json(req), span).await
} }
@ -447,12 +458,17 @@ seed,
async fn generate_stream( async fn generate_stream(
Extension(infer): Extension<Infer>, Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>, Extension(compute_type): Extension<ComputeType>,
Extension(context): Extension<Option<opentelemetry::Context>>,
Json(req): Json<GenerateRequest>, Json(req): Json<GenerateRequest>,
) -> ( ) -> (
HeaderMap, HeaderMap,
Sse<impl Stream<Item = Result<Event, Infallible>>>, Sse<impl Stream<Item = Result<Event, Infallible>>>,
) { ) {
let span = tracing::Span::current(); let span = tracing::Span::current();
if let Some(context) = context {
span.set_parent(context);
}
let (headers, response_stream) = let (headers, response_stream) =
generate_stream_internal(infer, compute_type, Json(req), span).await; generate_stream_internal(infer, compute_type, Json(req), span).await;
@ -682,9 +698,14 @@ async fn completions(
Extension(infer): Extension<Infer>, Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>, Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>, Extension(info): Extension<Info>,
Extension(context): Extension<Option<opentelemetry::Context>>,
Json(req): Json<CompletionRequest>, Json(req): Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current(); let span = tracing::Span::current();
if let Some(context) = context {
span.set_parent(context);
}
metrics::counter!("tgi_request_count").increment(1); metrics::counter!("tgi_request_count").increment(1);
let CompletionRequest { let CompletionRequest {
@ -1206,9 +1227,14 @@ async fn chat_completions(
Extension(infer): Extension<Infer>, Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>, Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>, Extension(info): Extension<Info>,
Extension(context): Extension<Option<opentelemetry::Context>>,
Json(chat): Json<ChatRequest>, Json(chat): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current(); let span = tracing::Span::current();
if let Some(context) = context {
span.set_parent(context);
}
metrics::counter!("tgi_request_count").increment(1); metrics::counter!("tgi_request_count").increment(1);
let ChatRequest { let ChatRequest {
stream, stream,
@ -2348,6 +2374,7 @@ async fn start(
.layer(Extension(compute_type)) .layer(Extension(compute_type))
.layer(Extension(prom_handle.clone())) .layer(Extension(prom_handle.clone()))
.layer(OtelAxumLayer::default()) .layer(OtelAxumLayer::default())
.layer(axum::middleware::from_fn(trace_context_middleware))
.layer(cors_layer); .layer(cors_layer);
tracing::info!("Connected"); tracing::info!("Connected");

View File

@ -10,6 +10,7 @@ use axum::response::{IntoResponse, Response};
use axum::Json; use axum::Json;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tracing::instrument; use tracing::instrument;
use tracing_opentelemetry::OpenTelemetrySpanExt;
use utoipa::ToSchema; use utoipa::ToSchema;
#[derive(Clone, Deserialize, ToSchema)] #[derive(Clone, Deserialize, ToSchema)]
@ -223,9 +224,14 @@ example = json ! ({"error": "Incomplete generation"})),
pub(crate) async fn vertex_compatibility( pub(crate) async fn vertex_compatibility(
Extension(infer): Extension<Infer>, Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>, Extension(compute_type): Extension<ComputeType>,
Extension(context): Extension<Option<opentelemetry::Context>>,
Json(req): Json<VertexRequest>, Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current(); let span = tracing::Span::current();
if let Some(context) = context {
span.set_parent(context);
}
metrics::counter!("tgi_request_count").increment(1); metrics::counter!("tgi_request_count").increment(1);
// check that theres at least one instance // check that theres at least one instance