feat(router): add prometheus metrics scrape endpoint (#71)
This commit is contained in:
parent
7b3d460d21
commit
439fcaf810
|
@ -8,6 +8,17 @@ version = "1.0.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
|
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ahash"
|
||||||
|
version = "0.7.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
|
||||||
|
dependencies = [
|
||||||
|
"getrandom",
|
||||||
|
"once_cell",
|
||||||
|
"version_check",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "aho-corasick"
|
name = "aho-corasick"
|
||||||
version = "0.7.20"
|
version = "0.7.20"
|
||||||
|
@ -806,6 +817,9 @@ name = "hashbrown"
|
||||||
version = "0.12.3"
|
version = "0.12.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
||||||
|
dependencies = [
|
||||||
|
"ahash",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "heck"
|
name = "heck"
|
||||||
|
@ -1093,6 +1107,15 @@ dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mach"
|
||||||
|
version = "0.3.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b823e83b2affd8f40a9ee8c29dbc56404c1e34cd2710921f2801e2cf29527afa"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "macro_rules_attribute"
|
name = "macro_rules_attribute"
|
||||||
version = "0.1.3"
|
version = "0.1.3"
|
||||||
|
@ -1139,6 +1162,64 @@ dependencies = [
|
||||||
"autocfg",
|
"autocfg",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "metrics"
|
||||||
|
version = "0.20.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7b9b8653cec6897f73b519a43fba5ee3d50f62fe9af80b428accdcc093b4a849"
|
||||||
|
dependencies = [
|
||||||
|
"ahash",
|
||||||
|
"metrics-macros",
|
||||||
|
"portable-atomic",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "metrics-exporter-prometheus"
|
||||||
|
version = "0.11.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8603921e1f54ef386189335f288441af761e0fc61bcb552168d9cedfe63ebc70"
|
||||||
|
dependencies = [
|
||||||
|
"hyper",
|
||||||
|
"indexmap",
|
||||||
|
"ipnet",
|
||||||
|
"metrics",
|
||||||
|
"metrics-util",
|
||||||
|
"parking_lot",
|
||||||
|
"portable-atomic",
|
||||||
|
"quanta",
|
||||||
|
"thiserror",
|
||||||
|
"tokio",
|
||||||
|
"tracing",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "metrics-macros"
|
||||||
|
version = "0.6.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "731f8ecebd9f3a4aa847dfe75455e4757a45da40a7793d2f0b1f9b6ed18b23f3"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "metrics-util"
|
||||||
|
version = "0.14.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f7d24dc2dbae22bff6f1f9326ffce828c9f07ef9cc1e8002e5279f845432a30a"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-epoch",
|
||||||
|
"crossbeam-utils",
|
||||||
|
"hashbrown",
|
||||||
|
"metrics",
|
||||||
|
"num_cpus",
|
||||||
|
"parking_lot",
|
||||||
|
"portable-atomic",
|
||||||
|
"quanta",
|
||||||
|
"sketches-ddsketch",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mime"
|
name = "mime"
|
||||||
version = "0.3.16"
|
version = "0.3.16"
|
||||||
|
@ -1514,6 +1595,12 @@ version = "0.3.26"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
|
checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "portable-atomic"
|
||||||
|
version = "0.3.19"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "26f6a7b87c2e435a3241addceeeff740ff8b7e76b74c13bf9acb17fa454ea00b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ppv-lite86"
|
name = "ppv-lite86"
|
||||||
version = "0.2.17"
|
version = "0.2.17"
|
||||||
|
@ -1618,6 +1705,22 @@ dependencies = [
|
||||||
"prost",
|
"prost",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "quanta"
|
||||||
|
version = "0.10.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b7e31331286705f455e56cca62e0e717158474ff02b7936c1fa596d983f4ae27"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-utils",
|
||||||
|
"libc",
|
||||||
|
"mach",
|
||||||
|
"once_cell",
|
||||||
|
"raw-cpuid",
|
||||||
|
"wasi 0.10.0+wasi-snapshot-preview1",
|
||||||
|
"web-sys",
|
||||||
|
"winapi",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "quote"
|
name = "quote"
|
||||||
version = "1.0.23"
|
version = "1.0.23"
|
||||||
|
@ -1657,6 +1760,15 @@ dependencies = [
|
||||||
"getrandom",
|
"getrandom",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "raw-cpuid"
|
||||||
|
version = "10.6.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c307f7aacdbab3f0adee67d52739a1d71112cc068d6fab169ddeb18e48877fad"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rayon"
|
name = "rayon"
|
||||||
version = "1.6.1"
|
version = "1.6.1"
|
||||||
|
@ -1980,6 +2092,12 @@ dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sketches-ddsketch"
|
||||||
|
version = "0.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ceb945e54128e09c43d8e4f1277851bd5044c6fc540bbaa2ad888f60b3da9ae7"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "slab"
|
name = "slab"
|
||||||
version = "0.4.7"
|
version = "0.4.7"
|
||||||
|
@ -2143,6 +2261,8 @@ dependencies = [
|
||||||
"axum-tracing-opentelemetry",
|
"axum-tracing-opentelemetry",
|
||||||
"clap 4.1.4",
|
"clap 4.1.4",
|
||||||
"futures",
|
"futures",
|
||||||
|
"metrics",
|
||||||
|
"metrics-exporter-prometheus",
|
||||||
"nohash-hasher",
|
"nohash-hasher",
|
||||||
"opentelemetry",
|
"opentelemetry",
|
||||||
"opentelemetry-otlp",
|
"opentelemetry-otlp",
|
||||||
|
|
|
@ -19,6 +19,8 @@ axum-tracing-opentelemetry = "0.9.0"
|
||||||
text-generation-client = { path = "client" }
|
text-generation-client = { path = "client" }
|
||||||
clap = { version = "4.1.4", features = ["derive", "env"] }
|
clap = { version = "4.1.4", features = ["derive", "env"] }
|
||||||
futures = "0.3.26"
|
futures = "0.3.26"
|
||||||
|
metrics = "0.20.1"
|
||||||
|
metrics-exporter-prometheus = { version = "0.11.0", features = [] }
|
||||||
nohash-hasher = "0.2.0"
|
nohash-hasher = "0.2.0"
|
||||||
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
|
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
|
||||||
opentelemetry-otlp = "0.11.0"
|
opentelemetry-otlp = "0.11.0"
|
||||||
|
|
|
@ -3,7 +3,6 @@ use crate::validation::{Validation, ValidationError};
|
||||||
use crate::GenerateRequest;
|
use crate::GenerateRequest;
|
||||||
use crate::{Entry, Queue, Token};
|
use crate::{Entry, Queue, Token};
|
||||||
use nohash_hasher::IntMap;
|
use nohash_hasher::IntMap;
|
||||||
use std::future::Future;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use text_generation_client::{
|
use text_generation_client::{
|
||||||
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
|
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
|
||||||
|
@ -81,6 +80,7 @@ impl Infer {
|
||||||
.limit_concurrent_requests
|
.limit_concurrent_requests
|
||||||
.try_acquire_owned()
|
.try_acquire_owned()
|
||||||
.map_err(|err| {
|
.map_err(|err| {
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
|
||||||
tracing::error!("{err}");
|
tracing::error!("{err}");
|
||||||
err
|
err
|
||||||
})?;
|
})?;
|
||||||
|
@ -172,6 +172,7 @@ impl Infer {
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
let err = InferError::IncompleteGeneration;
|
let err = InferError::IncompleteGeneration;
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
|
||||||
tracing::error!("{err}");
|
tracing::error!("{err}");
|
||||||
Err(err)
|
Err(err)
|
||||||
}
|
}
|
||||||
|
@ -201,7 +202,7 @@ async fn batching_task(
|
||||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||||
// waiting in the queue
|
// waiting in the queue
|
||||||
while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
|
while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
|
||||||
let mut cached_batch = wrap_future(client.prefill(batch), &mut entries)
|
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
||||||
.instrument(span)
|
.instrument(span)
|
||||||
.await;
|
.await;
|
||||||
let mut waiting_tokens = 1;
|
let mut waiting_tokens = 1;
|
||||||
|
@ -212,6 +213,7 @@ async fn batching_task(
|
||||||
// Get current batch info
|
// Get current batch info
|
||||||
let batch_size = batch.size;
|
let batch_size = batch.size;
|
||||||
let mut batches = vec![batch];
|
let mut batches = vec![batch];
|
||||||
|
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
|
||||||
|
|
||||||
// If the current batch is too small, we try to add more requests to it
|
// If the current batch is too small, we try to add more requests to it
|
||||||
if batch_size <= limit_min_batch_size {
|
if batch_size <= limit_min_batch_size {
|
||||||
|
@ -241,8 +243,7 @@ async fn batching_task(
|
||||||
});
|
});
|
||||||
|
|
||||||
// Generate one token for this new batch to have the attention past in cache
|
// Generate one token for this new batch to have the attention past in cache
|
||||||
let new_cached_batch =
|
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
|
||||||
wrap_future(client.prefill(new_batch), &mut new_entries)
|
|
||||||
.instrument(span)
|
.instrument(span)
|
||||||
.await;
|
.await;
|
||||||
// Reset waiting counter
|
// Reset waiting counter
|
||||||
|
@ -268,29 +269,59 @@ async fn batching_task(
|
||||||
entry.temp_span = Some(entry_batch_span);
|
entry.temp_span = Some(entry_batch_span);
|
||||||
});
|
});
|
||||||
|
|
||||||
cached_batch = wrap_future(client.decode(batches), &mut entries)
|
cached_batch = decode(&mut client, batches, &mut entries)
|
||||||
.instrument(next_batch_span)
|
.instrument(next_batch_span)
|
||||||
.await;
|
.await;
|
||||||
waiting_tokens += 1;
|
waiting_tokens += 1;
|
||||||
}
|
}
|
||||||
|
metrics::gauge!("tgi_batch_current_size", 0.0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
|
|
||||||
#[instrument(skip_all)]
|
#[instrument(skip_all)]
|
||||||
async fn wrap_future(
|
async fn prefill(
|
||||||
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
|
client: &mut ShardedClient,
|
||||||
|
batch: Batch,
|
||||||
entries: &mut IntMap<u64, Entry>,
|
entries: &mut IntMap<u64, Entry>,
|
||||||
) -> Option<Batch> {
|
) -> Option<Batch> {
|
||||||
match future.await {
|
let start_time = Instant::now();
|
||||||
|
|
||||||
|
match client.prefill(batch).await {
|
||||||
Ok((generations, next_batch)) => {
|
Ok((generations, next_batch)) => {
|
||||||
send_generations(generations, entries);
|
send_generations(generations, entries);
|
||||||
|
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "prefill");
|
||||||
|
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
|
||||||
next_batch
|
next_batch
|
||||||
}
|
}
|
||||||
// If we have an error, we discard the whole batch
|
// If we have an error, we discard the whole batch
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
send_errors(err, entries);
|
send_errors(err, entries);
|
||||||
|
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn decode(
|
||||||
|
client: &mut ShardedClient,
|
||||||
|
batches: Vec<Batch>,
|
||||||
|
entries: &mut IntMap<u64, Entry>,
|
||||||
|
) -> Option<Batch> {
|
||||||
|
let start_time = Instant::now();
|
||||||
|
|
||||||
|
match client.decode(batches).await {
|
||||||
|
Ok((generations, next_batch)) => {
|
||||||
|
send_generations(generations, entries);
|
||||||
|
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "decode");
|
||||||
|
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
|
||||||
|
next_batch
|
||||||
|
}
|
||||||
|
// If we have an error, we discard the whole batch
|
||||||
|
Err(err) => {
|
||||||
|
send_errors(err, entries);
|
||||||
|
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -303,6 +334,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||||
// Create and enter a span to link this function back to the entry
|
// Create and enter a span to link this function back to the entry
|
||||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||||
let err = InferError::GenerationError(error.to_string());
|
let err = InferError::GenerationError(error.to_string());
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
|
||||||
tracing::error!("{err}");
|
tracing::error!("{err}");
|
||||||
|
|
||||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||||
|
|
|
@ -132,6 +132,7 @@ impl State {
|
||||||
// Push entry in the queue
|
// Push entry in the queue
|
||||||
self.entries.push((self.next_id, entry));
|
self.entries.push((self.next_id, entry));
|
||||||
self.next_id += 1;
|
self.next_id += 1;
|
||||||
|
metrics::increment_gauge!("tgi_queue_size", 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the next batch
|
// Get the next batch
|
||||||
|
@ -190,6 +191,8 @@ impl State {
|
||||||
// Increment batch id
|
// Increment batch id
|
||||||
self.next_batch_id += 1;
|
self.next_batch_id += 1;
|
||||||
|
|
||||||
|
metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
|
||||||
|
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
|
||||||
Some((batch_entries, batch, next_batch_span))
|
Some((batch_entries, batch, next_batch_span))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ use axum::routing::{get, post};
|
||||||
use axum::{Json, Router};
|
use axum::{Json, Router};
|
||||||
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
|
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use text_generation_client::ShardedClient;
|
use text_generation_client::ShardedClient;
|
||||||
|
@ -57,14 +58,14 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
||||||
path = "/generate",
|
path = "/generate",
|
||||||
request_body = GenerateRequest,
|
request_body = GenerateRequest,
|
||||||
responses(
|
responses(
|
||||||
(status = 200, description = "Generated Text", body = [GenerateResponse]),
|
(status = 200, description = "Generated Text", body = GenerateResponse),
|
||||||
(status = 424, description = "Generation Error", body = [ErrorResponse],
|
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||||
example = json!({"error": "Request failed during generation"})),
|
example = json!({"error": "Request failed during generation"})),
|
||||||
(status = 429, description = "Model is overloaded", body = [ErrorResponse],
|
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||||
example = json!({"error": "Model is overloaded"})),
|
example = json!({"error": "Model is overloaded"})),
|
||||||
(status = 422, description = "Input validation error", body = [ErrorResponse],
|
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||||
example = json!({"error": "Input validation error"})),
|
example = json!({"error": "Input validation error"})),
|
||||||
(status = 500, description = "Incomplete generation", body = [ErrorResponse],
|
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||||
example = json!({"error": "Incomplete generation"})),
|
example = json!({"error": "Incomplete generation"})),
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
|
@ -141,6 +142,18 @@ async fn generate(
|
||||||
span.record("seed", format!("{:?}", response.generated_text.seed));
|
span.record("seed", format!("{:?}", response.generated_text.seed));
|
||||||
tracing::info!("Output: {}", response.generated_text.text);
|
tracing::info!("Output: {}", response.generated_text.text);
|
||||||
|
|
||||||
|
// Metrics
|
||||||
|
metrics::increment_counter!("tgi_request_success");
|
||||||
|
metrics::histogram!("tgi_request_duration", total_time);
|
||||||
|
metrics::histogram!("tgi_request_validation_duration", validation_time);
|
||||||
|
metrics::histogram!("tgi_request_queue_duration", queue_time);
|
||||||
|
metrics::histogram!("tgi_request_inference_duration", inference_time);
|
||||||
|
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
|
||||||
|
metrics::histogram!(
|
||||||
|
"tgi_request_generated_tokens",
|
||||||
|
response.generated_text.generated_tokens as f64
|
||||||
|
);
|
||||||
|
|
||||||
// Send response
|
// Send response
|
||||||
let response = GenerateResponse {
|
let response = GenerateResponse {
|
||||||
generated_text: response.generated_text.text,
|
generated_text: response.generated_text.text,
|
||||||
|
@ -156,20 +169,20 @@ async fn generate(
|
||||||
path = "/generate_stream",
|
path = "/generate_stream",
|
||||||
request_body = GenerateRequest,
|
request_body = GenerateRequest,
|
||||||
responses(
|
responses(
|
||||||
(status = 200, description = "Generated Text", body = [StreamResponse],
|
(status = 200, description = "Generated Text", body = StreamResponse,
|
||||||
content_type="text/event-stream "),
|
content_type="text/event-stream"),
|
||||||
(status = 424, description = "Generation Error", body = [ErrorResponse],
|
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||||
example = json!({"error": "Request failed during generation"}),
|
example = json!({"error": "Request failed during generation"}),
|
||||||
content_type="text/event-stream "),
|
content_type="text/event-stream"),
|
||||||
(status = 429, description = "Model is overloaded", body = [ErrorResponse],
|
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||||
example = json!({"error": "Model is overloaded"}),
|
example = json!({"error": "Model is overloaded"}),
|
||||||
content_type="text/event-stream "),
|
content_type="text/event-stream"),
|
||||||
(status = 422, description = "Input validation error", body = [ErrorResponse],
|
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||||
example = json!({"error": "Input validation error"}),
|
example = json!({"error": "Input validation error"}),
|
||||||
content_type="text/event-stream "),
|
content_type="text/event-stream"),
|
||||||
(status = 500, description = "Incomplete generation", body = [ErrorResponse],
|
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||||
example = json!({"error": "Incomplete generation"}),
|
example = json!({"error": "Incomplete generation"}),
|
||||||
content_type="text/event-stream "),
|
content_type="text/event-stream"),
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
#[instrument(
|
#[instrument(
|
||||||
|
@ -249,6 +262,15 @@ async fn generate_stream(
|
||||||
span.record("seed", format!("{:?}", generated_text.seed));
|
span.record("seed", format!("{:?}", generated_text.seed));
|
||||||
tracing::info!(parent: &span, "Output: {}", generated_text.text);
|
tracing::info!(parent: &span, "Output: {}", generated_text.text);
|
||||||
|
|
||||||
|
// Metrics
|
||||||
|
metrics::increment_counter!("tgi_request_success");
|
||||||
|
metrics::histogram!("tgi_request_duration", total_time);
|
||||||
|
metrics::histogram!("tgi_request_validation_duration", validation_time);
|
||||||
|
metrics::histogram!("tgi_request_queue_duration", queue_time);
|
||||||
|
metrics::histogram!("tgi_request_inference_duration", inference_time);
|
||||||
|
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
|
||||||
|
metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);
|
||||||
|
|
||||||
// StreamResponse
|
// StreamResponse
|
||||||
end_reached = true;
|
end_reached = true;
|
||||||
let stream_token = StreamResponse {
|
let stream_token = StreamResponse {
|
||||||
|
@ -279,6 +301,7 @@ async fn generate_stream(
|
||||||
// Skip if we already sent an error
|
// Skip if we already sent an error
|
||||||
if !end_reached && !error {
|
if !end_reached && !error {
|
||||||
let err = InferError::IncompleteGeneration;
|
let err = InferError::IncompleteGeneration;
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
|
||||||
tracing::error!("{err}");
|
tracing::error!("{err}");
|
||||||
yield Ok(Event::from(err))
|
yield Ok(Event::from(err))
|
||||||
}
|
}
|
||||||
|
@ -287,6 +310,17 @@ async fn generate_stream(
|
||||||
Sse::new(stream).keep_alive(KeepAlive::default())
|
Sse::new(stream).keep_alive(KeepAlive::default())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Prometheus metrics scrape endpoint
|
||||||
|
#[utoipa::path(
|
||||||
|
get,
|
||||||
|
tag = "Text Generation Inference",
|
||||||
|
path = "/metrics",
|
||||||
|
responses((status = 200, description = "Prometheus Metrics", body = String))
|
||||||
|
)]
|
||||||
|
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
|
||||||
|
prom_handle.render()
|
||||||
|
}
|
||||||
|
|
||||||
/// Serving method
|
/// Serving method
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub async fn run(
|
pub async fn run(
|
||||||
|
@ -307,6 +341,7 @@ pub async fn run(
|
||||||
paths(
|
paths(
|
||||||
generate,
|
generate,
|
||||||
generate_stream,
|
generate_stream,
|
||||||
|
metrics,
|
||||||
),
|
),
|
||||||
components(
|
components(
|
||||||
schemas(
|
schemas(
|
||||||
|
@ -350,6 +385,12 @@ pub async fn run(
|
||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Prometheus handler
|
||||||
|
let builder = PrometheusBuilder::new();
|
||||||
|
let prom_handle = builder
|
||||||
|
.install_recorder()
|
||||||
|
.expect("failed to install metrics recorder");
|
||||||
|
|
||||||
// Create router
|
// Create router
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
||||||
|
@ -359,6 +400,8 @@ pub async fn run(
|
||||||
.route("/", get(health))
|
.route("/", get(health))
|
||||||
.route("/health", get(health))
|
.route("/health", get(health))
|
||||||
.layer(Extension(infer))
|
.layer(Extension(infer))
|
||||||
|
.route("/metrics", get(metrics))
|
||||||
|
.layer(Extension(prom_handle))
|
||||||
.layer(opentelemetry_tracing_layer());
|
.layer(opentelemetry_tracing_layer());
|
||||||
|
|
||||||
// Run server
|
// Run server
|
||||||
|
|
|
@ -13,7 +13,7 @@ use tracing::{instrument, Span};
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Validation {
|
pub struct Validation {
|
||||||
/// Channel to communicate with the background validation task
|
/// Channel to communicate with the background validation task
|
||||||
sender: mpsc::Sender<ValidationRequest>,
|
sender: mpsc::UnboundedSender<ValidationRequest>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Validation {
|
impl Validation {
|
||||||
|
@ -25,7 +25,7 @@ impl Validation {
|
||||||
max_total_tokens: usize,
|
max_total_tokens: usize,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Create channel
|
// Create channel
|
||||||
let (validation_sender, validation_receiver) = mpsc::channel(128);
|
let (validation_sender, validation_receiver) = mpsc::unbounded_channel();
|
||||||
|
|
||||||
// Launch background validation task
|
// Launch background validation task
|
||||||
tokio::spawn(validation_task(
|
tokio::spawn(validation_task(
|
||||||
|
@ -54,7 +54,6 @@ impl Validation {
|
||||||
// Unwrap is safe here
|
// Unwrap is safe here
|
||||||
self.sender
|
self.sender
|
||||||
.send((request, sender, Span::current()))
|
.send((request, sender, Span::current()))
|
||||||
.await
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
// Await on response channel
|
// Await on response channel
|
||||||
// Unwrap is safe here
|
// Unwrap is safe here
|
||||||
|
@ -70,7 +69,7 @@ async fn validation_task(
|
||||||
max_stop_sequences: usize,
|
max_stop_sequences: usize,
|
||||||
max_input_length: usize,
|
max_input_length: usize,
|
||||||
max_total_tokens: usize,
|
max_total_tokens: usize,
|
||||||
mut receiver: mpsc::Receiver<ValidationRequest>,
|
mut receiver: mpsc::UnboundedReceiver<ValidationRequest>,
|
||||||
) {
|
) {
|
||||||
let mut workers_senders = Vec::with_capacity(workers);
|
let mut workers_senders = Vec::with_capacity(workers);
|
||||||
|
|
||||||
|
@ -131,6 +130,7 @@ fn validation_worker(
|
||||||
&mut rng,
|
&mut rng,
|
||||||
)
|
)
|
||||||
.map_err(|err| {
|
.map_err(|err| {
|
||||||
|
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||||
tracing::error!("{err}");
|
tracing::error!("{err}");
|
||||||
err
|
err
|
||||||
}),
|
}),
|
||||||
|
@ -214,6 +214,7 @@ fn validate(
|
||||||
Ok(encoding) => {
|
Ok(encoding) => {
|
||||||
let input_length = encoding.len();
|
let input_length = encoding.len();
|
||||||
let total_tokens = input_length + max_new_tokens as usize;
|
let total_tokens = input_length + max_new_tokens as usize;
|
||||||
|
|
||||||
if input_length > max_input_length {
|
if input_length > max_input_length {
|
||||||
Err(ValidationError::InputLength(max_input_length, input_length))
|
Err(ValidationError::InputLength(max_input_length, input_length))
|
||||||
} else if total_tokens > max_total_tokens {
|
} else if total_tokens > max_total_tokens {
|
||||||
|
@ -237,6 +238,9 @@ fn validate(
|
||||||
stop_sequences,
|
stop_sequences,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
metrics::histogram!("tgi_request_input_length", input_length as f64);
|
||||||
|
metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
|
||||||
|
|
||||||
Ok(ValidGenerateRequest {
|
Ok(ValidGenerateRequest {
|
||||||
inputs: request.inputs,
|
inputs: request.inputs,
|
||||||
input_length: input_length as u32,
|
input_length: input_length as u32,
|
||||||
|
|
|
@ -49,6 +49,8 @@ def convert_file(pt_file: Path, st_file: Path):
|
||||||
"""
|
"""
|
||||||
Convert a pytorch file to a safetensors file
|
Convert a pytorch file to a safetensors file
|
||||||
"""
|
"""
|
||||||
|
logger.info(f"Convert {pt_file} to {st_file}.")
|
||||||
|
|
||||||
pt_state = torch.load(pt_file, map_location="cpu")
|
pt_state = torch.load(pt_file, map_location="cpu")
|
||||||
if "state_dict" in pt_state:
|
if "state_dict" in pt_state:
|
||||||
pt_state = pt_state["state_dict"]
|
pt_state = pt_state["state_dict"]
|
||||||
|
|
|
@ -132,9 +132,9 @@ def download_weights(
|
||||||
local_file = try_to_load_from_cache(model_id, revision, filename)
|
local_file = try_to_load_from_cache(model_id, revision, filename)
|
||||||
if local_file is not None:
|
if local_file is not None:
|
||||||
logger.info(f"File {filename} already present in cache.")
|
logger.info(f"File {filename} already present in cache.")
|
||||||
return local_file
|
return Path(local_file)
|
||||||
|
|
||||||
logger.info(f"Starting {filename} download.")
|
logger.info(f"Download file: {filename}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
local_file = hf_hub_download(
|
local_file = hf_hub_download(
|
||||||
filename=filename,
|
filename=filename,
|
||||||
|
@ -143,7 +143,7 @@ def download_weights(
|
||||||
local_files_only=False,
|
local_files_only=False,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Downloaded {filename} at {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
|
f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
|
||||||
)
|
)
|
||||||
return Path(local_file)
|
return Path(local_file)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue