Merge commit 'refs/pull/2076/head' of github.com:huggingface/text-generation-inference into main

This commit is contained in:
drbh 2024-06-20 00:41:54 +00:00
commit 75dccec40d
7 changed files with 41 additions and 10 deletions

View File

@ -153,7 +153,8 @@ this will impact performance.
### Distributed Tracing
`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature
by setting the address to an OTLP collector with the `--otlp-endpoint` argument.
by setting the address to an OTLP collector with the `--otlp-endpoint` argument. The default service name can be
overridden with the `--otlp-service-name` argument
### Architecture

View File

@ -70,6 +70,8 @@ Options:
[env: JSON_OUTPUT=]
--otlp-endpoint <OTLP_ENDPOINT>
[env: OTLP_ENDPOINT=]
--otlp-service-name <OTLP_SERVICE_NAME>
[env: OTLP_SERVICE_NAME=]
--cors-allow-origin <CORS_ALLOW_ORIGIN>
[env: CORS_ALLOW_ORIGIN=]
--ngrok
@ -138,6 +140,8 @@ Serve's command line parameters on the TGI repository are these:
│ --logger-level TEXT [default: INFO] │
│ --json-output --no-json-output [default: no-json-output] │
│ --otlp-endpoint TEXT [default: None] │
│ --otlp-service-name TEXT [default: │
│ text-generation-inference...│
│ --help Show this message and exit. │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯
```

View File

@ -336,6 +336,12 @@ Options:
--otlp-endpoint <OTLP_ENDPOINT>
[env: OTLP_ENDPOINT=]
```
## OTLP_SERVICE_NAME
```shell
--otlp-service-name <OTLP_SERVICE_NAME>
[env: OTLP_SERVICE_NAME=]
```
## CORS_ALLOW_ORIGIN
```shell

View File

@ -413,6 +413,9 @@ struct Args {
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: Option<String>,
#[clap(long, env)]
cors_allow_origin: Vec<String>,
#[clap(long, env)]
@ -483,6 +486,7 @@ fn shard_manager(
max_batch_size: Option<usize>,
max_input_tokens: usize,
otlp_endpoint: Option<String>,
otlp_service_name: String,
log_level: LevelFilter,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<AtomicBool>,
@ -548,12 +552,18 @@ fn shard_manager(
(None, Some(factor)) => Some((RopeScaling::Linear, factor)),
};
// OpenTelemetry
// OpenTelemetry Endpoint
if let Some(otlp_endpoint) = otlp_endpoint {
shard_args.push("--otlp-endpoint".to_string());
shard_args.push(otlp_endpoint);
}
// OpenTelemetry Service Name
if let Some(otlp_endpoint) = otlp_endpoint {
shard_args.push("--otlp-service-name".to_string());
shard_args.push(otlp_service_name);
}
// In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
shard_args.push("--max-input-tokens".to_string());
shard_args.push(max_input_tokens.to_string());
@ -1035,6 +1045,7 @@ fn spawn_shards(
let shutdown = shutdown.clone();
let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = args.otlp_endpoint.clone();
let otlp_service_name = args.otlp_service_name.clone();
let quantize = args.quantize;
let speculate = args.speculate;
let dtype = args.dtype;
@ -1074,6 +1085,7 @@ fn spawn_shards(
max_batch_size,
max_input_tokens,
otlp_endpoint,
otlp_service_name,
max_log_level,
status_sender,
shutdown,
@ -1207,6 +1219,11 @@ fn spawn_webserver(
router_args.push(otlp_endpoint);
}
// OpenTelemetry
if args.otlp_service_name {
router_args.push("--otlp-service-name".to_string());
}
// CORS origins
for origin in args.cors_allow_origin.into_iter() {
router_args.push("--cors-allow-origin".to_string());

View File

@ -65,6 +65,8 @@ struct Args {
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env)]
@ -107,6 +109,7 @@ async fn main() -> Result<(), RouterError> {
validation_workers,
json_output,
otlp_endpoint,
otlp_service_name,
cors_allow_origin,
ngrok,
ngrok_authtoken,
@ -117,7 +120,7 @@ async fn main() -> Result<(), RouterError> {
} = args;
// Launch Tokio runtime
init_logging(otlp_endpoint, json_output);
init_logging(otlp_endpoint, otlp_service_name, json_output);
// Validate args
if max_input_tokens >= max_total_tokens {
@ -367,10 +370,11 @@ async fn main() -> Result<(), RouterError> {
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
/// - otlp_service_name service name to appear in APM
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
let mut layers = Vec::new();
// STDOUT/STDERR layer
@ -401,7 +405,7 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
trace::config()
.with_resource(Resource::new(vec![KeyValue::new(
"service.name",
"text-generation-inference.router",
otlp_service_name,
)]))
.with_sampler(Sampler::AlwaysOn),
)

View File

@ -42,6 +42,7 @@ def serve(
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
otlp_service_name: str = "text-generation-inference.server",
max_input_tokens: Optional[int] = None,
):
if sharded:
@ -76,7 +77,7 @@ def serve(
# Setup OpenTelemetry distributed tracing
if otlp_endpoint is not None:
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value

View File

@ -54,10 +54,8 @@ class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor):
)
def setup_tracing(shard: int, otlp_endpoint: str):
resource = Resource.create(
attributes={"service.name": f"text-generation-inference.server-{shard}"}
)
def setup_tracing(otlp_service_name: str, otlp_endpoint: str):
resource = Resource.create(attributes={"service.name": otlp_service_name})
span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)
span_processor = BatchSpanProcessor(span_exporter)