feature: get trace id from req headers
parent 58848cb471
commit 64b0337574
@@ -1,13 +1,68 @@
+use axum::{extract::Request, middleware::Next, response::Response};
 use opentelemetry::sdk::propagation::TraceContextPropagator;
 use opentelemetry::sdk::trace;
 use opentelemetry::sdk::trace::Sampler;
 use opentelemetry::sdk::Resource;
+use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId};
+use opentelemetry::Context;
 use opentelemetry::{global, KeyValue};
 use opentelemetry_otlp::WithExportConfig;
 use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
 use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
 
+struct TraceParent {
+    #[allow(dead_code)]
+    version: u8,
+    trace_id: TraceId,
+    parent_id: SpanId,
+    trace_flags: TraceFlags,
+}
+
+fn parse_traceparent(header_value: &str) -> Option<TraceParent> {
+    let parts: Vec<&str> = header_value.split('-').collect();
+    if parts.len() != 4 {
+        return None;
+    }
+
+    let version = u8::from_str_radix(parts[0], 16).ok()?;
+    if version == 0xff {
+        return None;
+    }
+
+    let trace_id = TraceId::from_hex(parts[1]).ok()?;
+    let parent_id = SpanId::from_hex(parts[2]).ok()?;
+    let trace_flags = u8::from_str_radix(parts[3], 16).ok()?;
+
+    Some(TraceParent {
+        version,
+        trace_id,
+        parent_id,
+        trace_flags: TraceFlags::new(trace_flags),
+    })
+}
+
+pub async fn trace_context_middleware(mut request: Request, next: Next) -> Response {
+    let context = request
+        .headers()
+        .get("traceparent")
+        .and_then(|v| v.to_str().ok())
+        .and_then(parse_traceparent)
+        .map(|traceparent| {
+            Context::new().with_remote_span_context(SpanContext::new(
+                traceparent.trace_id,
+                traceparent.parent_id,
+                traceparent.trace_flags,
+                true,
+                Default::default(),
+            ))
+        });
+
+    request.extensions_mut().insert(context);
+
+    next.run(request).await
+}
+
 /// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
 /// - otlp_endpoint is an optional URL to an Open Telemetry collector
 /// - otlp_service_name service name to appear in APM
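Note: the traceparent value parsed above follows the W3C Trace Context format version-trace_id-parent_id-trace_flags. The following unit-test sketch is not part of this commit; it only illustrates, using the canonical example value from the W3C spec, how parse_traceparent is expected to behave on valid, malformed, and reserved-version input:

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_w3c_example_traceparent() {
        // Canonical example value from the W3C Trace Context specification.
        let header = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";
        let parsed = parse_traceparent(header).expect("valid traceparent should parse");
        assert_eq!(
            parsed.trace_id,
            TraceId::from_hex("4bf92f3577b34da6a3ce929d0e0e4736").unwrap()
        );
        assert_eq!(parsed.parent_id, SpanId::from_hex("00f067aa0ba902b7").unwrap());
        assert!(parsed.trace_flags.is_sampled());

        // Wrong number of segments and the reserved version 0xff are both rejected.
        assert!(parse_traceparent("not-a-traceparent").is_none());
        assert!(
            parse_traceparent("ff-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01").is_none()
        );
    }
}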
@@ -7,6 +7,7 @@ use crate::kserve::{
     kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
     kserve_model_metadata, kserve_model_metadata_ready,
 };
+use crate::logging::trace_context_middleware;
 use crate::validation::ValidationError;
 use crate::vertex::vertex_compatibility;
 use crate::ChatTokenizeResponse;
@@ -57,6 +58,7 @@ use tokio::sync::oneshot;
 use tokio::time::Instant;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
+use tracing_opentelemetry::OpenTelemetrySpanExt;
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;
 
@@ -87,6 +89,7 @@ async fn compat_generate(
     Extension(default_return_full_text): Extension<bool>,
     infer: Extension<Infer>,
     compute_type: Extension<ComputeType>,
+    context: Extension<Option<opentelemetry::Context>>,
     Json(mut req): Json<CompatGenerateRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     // default return_full_text given the pipeline_tag
@@ -96,11 +99,14 @@ async fn compat_generate(
 
     // switch on stream
     if req.stream {
-        Ok(generate_stream(infer, compute_type, Json(req.into()))
-            .await
-            .into_response())
+        Ok(
+            generate_stream(infer, compute_type, context, Json(req.into()))
+                .await
+                .into_response(),
+        )
     } else {
-        let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?;
+        let (headers, Json(generation)) =
+            generate(infer, compute_type, context, Json(req.into())).await?;
         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(vec![generation])).into_response())
     }
@@ -251,9 +257,14 @@ seed,
 async fn generate(
     infer: Extension<Infer>,
     Extension(ComputeType(compute_type)): Extension<ComputeType>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     generate_internal(infer, ComputeType(compute_type), Json(req), span).await
 }
 
@@ -447,12 +458,17 @@ seed,
 async fn generate_stream(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<GenerateRequest>,
 ) -> (
     HeaderMap,
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
 ) {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     let (headers, response_stream) =
         generate_stream_internal(infer, compute_type, Json(req), span).await;
 
@@ -682,9 +698,14 @@ async fn completions(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<CompletionRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     metrics::counter!("tgi_request_count").increment(1);
 
     let CompletionRequest {
@@ -1206,9 +1227,14 @@ async fn chat_completions(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
     Extension(info): Extension<Info>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(chat): Json<ChatRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     metrics::counter!("tgi_request_count").increment(1);
     let ChatRequest {
         stream,
@@ -2348,6 +2374,7 @@ async fn start(
         .layer(Extension(compute_type))
         .layer(Extension(prom_handle.clone()))
         .layer(OtelAxumLayer::default())
+        .layer(axum::middleware::from_fn(trace_context_middleware))
         .layer(cors_layer);
 
     tracing::info!("Connected");
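For illustration only (not part of this commit): once the middleware is registered above, a caller that already has a trace in flight can link the router's spans to it just by sending the header trace_context_middleware reads. The sketch assumes a reqwest dependency with its json feature, the /generate route, a router on localhost:3000, and a minimal hypothetical request body.

use reqwest::Client;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let response = Client::new()
        .post("http://localhost:3000/generate")
        // W3C Trace Context header: version-trace_id-parent_id-trace_flags.
        // The router's handler spans will be parented to this trace id.
        .header(
            "traceparent",
            "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
        )
        .json(&serde_json::json!({ "inputs": "Hello" }))
        .send()
        .await?;
    println!("status: {}", response.status());
    Ok(())
}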
@@ -10,6 +10,7 @@ use axum::response::{IntoResponse, Response};
 use axum::Json;
 use serde::{Deserialize, Serialize};
 use tracing::instrument;
+use tracing_opentelemetry::OpenTelemetrySpanExt;
 use utoipa::ToSchema;
 
 #[derive(Clone, Deserialize, ToSchema)]
@@ -223,9 +224,14 @@ example = json ! ({"error": "Incomplete generation"})),
 pub(crate) async fn vertex_compatibility(
     Extension(infer): Extension<Infer>,
     Extension(compute_type): Extension<ComputeType>,
+    Extension(context): Extension<Option<opentelemetry::Context>>,
     Json(req): Json<VertexRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
     metrics::counter!("tgi_request_count").increment(1);
 
     // check that theres at least one instance