This commit is contained in:
Kevin Duffy 2024-06-17 11:24:51 +01:00
parent e903770897
commit 6e93482c46
4 changed files with 29 additions and 9 deletions

View File

@ -413,6 +413,9 @@ struct Args {
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: Option<String>,
#[clap(long, env)]
cors_allow_origin: Vec<String>,
#[clap(long, env)]
@ -483,6 +486,7 @@ fn shard_manager(
max_batch_size: Option<usize>,
max_input_tokens: usize,
otlp_endpoint: Option<String>,
otlp_service_name: String,
log_level: LevelFilter,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<AtomicBool>,
@ -548,12 +552,18 @@ fn shard_manager(
(None, Some(factor)) => Some((RopeScaling::Linear, factor)),
};
// OpenTelemetry
// OpenTelemetry Endpoint
if let Some(otlp_endpoint) = otlp_endpoint {
shard_args.push("--otlp-endpoint".to_string());
shard_args.push(otlp_endpoint);
}
// OpenTelemetry Service Name
if let Some(otlp_endpoint) = otlp_endpoint {
shard_args.push("--otlp-service-name".to_string());
shard_args.push(otlp_service_name);
}
// In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
shard_args.push("--max-input-tokens".to_string());
shard_args.push(max_input_tokens.to_string());
@ -1035,6 +1045,7 @@ fn spawn_shards(
let shutdown = shutdown.clone();
let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = args.otlp_endpoint.clone();
let otlp_service_name = args.otlp_service_name.clone();
let quantize = args.quantize;
let speculate = args.speculate;
let dtype = args.dtype;
@ -1074,6 +1085,7 @@ fn spawn_shards(
max_batch_size,
max_input_tokens,
otlp_endpoint,
otlp_service_name,
max_log_level,
status_sender,
shutdown,
@ -1207,6 +1219,11 @@ fn spawn_webserver(
router_args.push(otlp_endpoint);
}
// OpenTelemetry
if args.otlp_service_name {
router_args.push("--otlp-service-name".to_string());
}
// CORS origins
for origin in args.cors_allow_origin.into_iter() {
router_args.push("--cors-allow-origin".to_string());

View File

@ -65,6 +65,8 @@ struct Args {
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env)]
@ -107,6 +109,7 @@ async fn main() -> Result<(), RouterError> {
validation_workers,
json_output,
otlp_endpoint,
otlp_service_name,
cors_allow_origin,
ngrok,
ngrok_authtoken,
@ -117,7 +120,7 @@ async fn main() -> Result<(), RouterError> {
} = args;
// Launch Tokio runtime
init_logging(otlp_endpoint, json_output);
init_logging(otlp_endpoint, otlp_service_name, json_output);
// Validate args
if max_input_tokens >= max_total_tokens {
@ -367,10 +370,11 @@ async fn main() -> Result<(), RouterError> {
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
/// - otlp_service_name service name to appear in APM
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
let mut layers = Vec::new();
// STDOUT/STDERR layer
@ -401,7 +405,7 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
trace::config()
.with_resource(Resource::new(vec![KeyValue::new(
"service.name",
"text-generation-inference.router",
otlp_service_name,
)]))
.with_sampler(Sampler::AlwaysOn),
)

View File

@ -42,6 +42,7 @@ def serve(
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
otlp_service_name: str = "text-generation-inference.server",
max_input_tokens: Optional[int] = None,
):
if sharded:
@ -76,7 +77,7 @@ def serve(
# Setup OpenTelemetry distributed tracing
if otlp_endpoint is not None:
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value

View File

@ -54,10 +54,8 @@ class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor):
)
def setup_tracing(shard: int, otlp_endpoint: str):
resource = Resource.create(
attributes={"service.name": f"text-generation-inference.server-{shard}"}
)
def setup_tracing(otlp_service_name: str, otlp_endpoint: str):
resource = Resource.create(attributes={"service.name": otlp_service_name})
span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
span_processor = BatchSpanProcessor(span_exporter)