feature: get trace id from req headers
This commit is contained in:
parent
58848cb471
commit
64b0337574
|
@ -1,13 +1,68 @@
|
|||
use axum::{extract::Request, middleware::Next, response::Response};
|
||||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
||||
use opentelemetry::sdk::trace;
|
||||
use opentelemetry::sdk::trace::Sampler;
|
||||
use opentelemetry::sdk::Resource;
|
||||
use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId};
|
||||
use opentelemetry::Context;
|
||||
use opentelemetry::{global, KeyValue};
|
||||
use opentelemetry_otlp::WithExportConfig;
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
|
||||
|
||||
struct TraceParent {
|
||||
#[allow(dead_code)]
|
||||
version: u8,
|
||||
trace_id: TraceId,
|
||||
parent_id: SpanId,
|
||||
trace_flags: TraceFlags,
|
||||
}
|
||||
|
||||
fn parse_traceparent(header_value: &str) -> Option<TraceParent> {
|
||||
let parts: Vec<&str> = header_value.split('-').collect();
|
||||
if parts.len() != 4 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let version = u8::from_str_radix(parts[0], 16).ok()?;
|
||||
if version == 0xff {
|
||||
return None;
|
||||
}
|
||||
|
||||
let trace_id = TraceId::from_hex(parts[1]).ok()?;
|
||||
let parent_id = SpanId::from_hex(parts[2]).ok()?;
|
||||
let trace_flags = u8::from_str_radix(parts[3], 16).ok()?;
|
||||
|
||||
Some(TraceParent {
|
||||
version,
|
||||
trace_id,
|
||||
parent_id,
|
||||
trace_flags: TraceFlags::new(trace_flags),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn trace_context_middleware(mut request: Request, next: Next) -> Response {
|
||||
let context = request
|
||||
.headers()
|
||||
.get("traceparent")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(parse_traceparent)
|
||||
.map(|traceparent| {
|
||||
Context::new().with_remote_span_context(SpanContext::new(
|
||||
traceparent.trace_id,
|
||||
traceparent.parent_id,
|
||||
traceparent.trace_flags,
|
||||
true,
|
||||
Default::default(),
|
||||
))
|
||||
});
|
||||
|
||||
request.extensions_mut().insert(context);
|
||||
|
||||
next.run(request).await
|
||||
}
|
||||
|
||||
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
|
||||
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
|
||||
/// - otlp_service_name service name to appear in APM
|
||||
|
|
|
@ -7,6 +7,7 @@ use crate::kserve::{
|
|||
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
|
||||
kserve_model_metadata, kserve_model_metadata_ready,
|
||||
};
|
||||
use crate::logging::trace_context_middleware;
|
||||
use crate::validation::ValidationError;
|
||||
use crate::vertex::vertex_compatibility;
|
||||
use crate::ChatTokenizeResponse;
|
||||
|
@ -57,6 +58,7 @@ use tokio::sync::oneshot;
|
|||
use tokio::time::Instant;
|
||||
use tower_http::cors::{AllowOrigin, CorsLayer};
|
||||
use tracing::{info_span, instrument, Instrument};
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
use utoipa::OpenApi;
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
|
||||
|
@ -87,6 +89,7 @@ async fn compat_generate(
|
|||
Extension(default_return_full_text): Extension<bool>,
|
||||
infer: Extension<Infer>,
|
||||
compute_type: Extension<ComputeType>,
|
||||
context: Extension<Option<opentelemetry::Context>>,
|
||||
Json(mut req): Json<CompatGenerateRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
// default return_full_text given the pipeline_tag
|
||||
|
@ -96,11 +99,14 @@ async fn compat_generate(
|
|||
|
||||
// switch on stream
|
||||
if req.stream {
|
||||
Ok(generate_stream(infer, compute_type, Json(req.into()))
|
||||
.await
|
||||
.into_response())
|
||||
Ok(
|
||||
generate_stream(infer, compute_type, context, Json(req.into()))
|
||||
.await
|
||||
.into_response(),
|
||||
)
|
||||
} else {
|
||||
let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?;
|
||||
let (headers, Json(generation)) =
|
||||
generate(infer, compute_type, context, Json(req.into())).await?;
|
||||
// wrap generation inside a Vec to match api-inference
|
||||
Ok((headers, Json(vec![generation])).into_response())
|
||||
}
|
||||
|
@ -251,9 +257,14 @@ seed,
|
|||
async fn generate(
|
||||
infer: Extension<Infer>,
|
||||
Extension(ComputeType(compute_type)): Extension<ComputeType>,
|
||||
Extension(context): Extension<Option<opentelemetry::Context>>,
|
||||
Json(req): Json<GenerateRequest>,
|
||||
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||
let span = tracing::Span::current();
|
||||
if let Some(context) = context {
|
||||
span.set_parent(context);
|
||||
}
|
||||
|
||||
generate_internal(infer, ComputeType(compute_type), Json(req), span).await
|
||||
}
|
||||
|
||||
|
@ -447,12 +458,17 @@ seed,
|
|||
async fn generate_stream(
|
||||
Extension(infer): Extension<Infer>,
|
||||
Extension(compute_type): Extension<ComputeType>,
|
||||
Extension(context): Extension<Option<opentelemetry::Context>>,
|
||||
Json(req): Json<GenerateRequest>,
|
||||
) -> (
|
||||
HeaderMap,
|
||||
Sse<impl Stream<Item = Result<Event, Infallible>>>,
|
||||
) {
|
||||
let span = tracing::Span::current();
|
||||
if let Some(context) = context {
|
||||
span.set_parent(context);
|
||||
}
|
||||
|
||||
let (headers, response_stream) =
|
||||
generate_stream_internal(infer, compute_type, Json(req), span).await;
|
||||
|
||||
|
@ -682,9 +698,14 @@ async fn completions(
|
|||
Extension(infer): Extension<Infer>,
|
||||
Extension(compute_type): Extension<ComputeType>,
|
||||
Extension(info): Extension<Info>,
|
||||
Extension(context): Extension<Option<opentelemetry::Context>>,
|
||||
Json(req): Json<CompletionRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let span = tracing::Span::current();
|
||||
if let Some(context) = context {
|
||||
span.set_parent(context);
|
||||
}
|
||||
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
|
||||
let CompletionRequest {
|
||||
|
@ -1206,9 +1227,14 @@ async fn chat_completions(
|
|||
Extension(infer): Extension<Infer>,
|
||||
Extension(compute_type): Extension<ComputeType>,
|
||||
Extension(info): Extension<Info>,
|
||||
Extension(context): Extension<Option<opentelemetry::Context>>,
|
||||
Json(chat): Json<ChatRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let span = tracing::Span::current();
|
||||
if let Some(context) = context {
|
||||
span.set_parent(context);
|
||||
}
|
||||
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
let ChatRequest {
|
||||
stream,
|
||||
|
@ -2348,6 +2374,7 @@ async fn start(
|
|||
.layer(Extension(compute_type))
|
||||
.layer(Extension(prom_handle.clone()))
|
||||
.layer(OtelAxumLayer::default())
|
||||
.layer(axum::middleware::from_fn(trace_context_middleware))
|
||||
.layer(cors_layer);
|
||||
|
||||
tracing::info!("Connected");
|
||||
|
|
|
@ -10,6 +10,7 @@ use axum::response::{IntoResponse, Response};
|
|||
use axum::Json;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::instrument;
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema)]
|
||||
|
@ -223,9 +224,14 @@ example = json ! ({"error": "Incomplete generation"})),
|
|||
pub(crate) async fn vertex_compatibility(
|
||||
Extension(infer): Extension<Infer>,
|
||||
Extension(compute_type): Extension<ComputeType>,
|
||||
Extension(context): Extension<Option<opentelemetry::Context>>,
|
||||
Json(req): Json<VertexRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let span = tracing::Span::current();
|
||||
if let Some(context) = context {
|
||||
span.set_parent(context);
|
||||
}
|
||||
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
|
||||
// check that theres at least one instance
|
||||
|
|
Loading…
Reference in New Issue