feat(router): add prometheus metrics scrape endpoint (#71)
This commit is contained in:
parent
7b3d460d21
commit
439fcaf810
|
@ -8,6 +8,17 @@ version = "1.0.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"once_cell",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "0.7.20"
|
||||
|
@ -806,6 +817,9 @@ name = "hashbrown"
|
|||
version = "0.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
|
@ -1093,6 +1107,15 @@ dependencies = [
|
|||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mach"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b823e83b2affd8f40a9ee8c29dbc56404c1e34cd2710921f2801e2cf29527afa"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "macro_rules_attribute"
|
||||
version = "0.1.3"
|
||||
|
@ -1139,6 +1162,64 @@ dependencies = [
|
|||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "metrics"
|
||||
version = "0.20.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b9b8653cec6897f73b519a43fba5ee3d50f62fe9af80b428accdcc093b4a849"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"metrics-macros",
|
||||
"portable-atomic",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "metrics-exporter-prometheus"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8603921e1f54ef386189335f288441af761e0fc61bcb552168d9cedfe63ebc70"
|
||||
dependencies = [
|
||||
"hyper",
|
||||
"indexmap",
|
||||
"ipnet",
|
||||
"metrics",
|
||||
"metrics-util",
|
||||
"parking_lot",
|
||||
"portable-atomic",
|
||||
"quanta",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "metrics-macros"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "731f8ecebd9f3a4aa847dfe75455e4757a45da40a7793d2f0b1f9b6ed18b23f3"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "metrics-util"
|
||||
version = "0.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f7d24dc2dbae22bff6f1f9326ffce828c9f07ef9cc1e8002e5279f845432a30a"
|
||||
dependencies = [
|
||||
"crossbeam-epoch",
|
||||
"crossbeam-utils",
|
||||
"hashbrown",
|
||||
"metrics",
|
||||
"num_cpus",
|
||||
"parking_lot",
|
||||
"portable-atomic",
|
||||
"quanta",
|
||||
"sketches-ddsketch",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mime"
|
||||
version = "0.3.16"
|
||||
|
@ -1514,6 +1595,12 @@ version = "0.3.26"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
|
||||
|
||||
[[package]]
|
||||
name = "portable-atomic"
|
||||
version = "0.3.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26f6a7b87c2e435a3241addceeeff740ff8b7e76b74c13bf9acb17fa454ea00b"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.17"
|
||||
|
@ -1618,6 +1705,22 @@ dependencies = [
|
|||
"prost",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quanta"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b7e31331286705f455e56cca62e0e717158474ff02b7936c1fa596d983f4ae27"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
"libc",
|
||||
"mach",
|
||||
"once_cell",
|
||||
"raw-cpuid",
|
||||
"wasi 0.10.0+wasi-snapshot-preview1",
|
||||
"web-sys",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.23"
|
||||
|
@ -1657,6 +1760,15 @@ dependencies = [
|
|||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "raw-cpuid"
|
||||
version = "10.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c307f7aacdbab3f0adee67d52739a1d71112cc068d6fab169ddeb18e48877fad"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.6.1"
|
||||
|
@ -1980,6 +2092,12 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sketches-ddsketch"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ceb945e54128e09c43d8e4f1277851bd5044c6fc540bbaa2ad888f60b3da9ae7"
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.7"
|
||||
|
@ -2143,6 +2261,8 @@ dependencies = [
|
|||
"axum-tracing-opentelemetry",
|
||||
"clap 4.1.4",
|
||||
"futures",
|
||||
"metrics",
|
||||
"metrics-exporter-prometheus",
|
||||
"nohash-hasher",
|
||||
"opentelemetry",
|
||||
"opentelemetry-otlp",
|
||||
|
|
|
@ -19,6 +19,8 @@ axum-tracing-opentelemetry = "0.9.0"
|
|||
text-generation-client = { path = "client" }
|
||||
clap = { version = "4.1.4", features = ["derive", "env"] }
|
||||
futures = "0.3.26"
|
||||
metrics = "0.20.1"
|
||||
metrics-exporter-prometheus = { version = "0.11.0", features = [] }
|
||||
nohash-hasher = "0.2.0"
|
||||
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
|
||||
opentelemetry-otlp = "0.11.0"
|
||||
|
|
|
@ -3,7 +3,6 @@ use crate::validation::{Validation, ValidationError};
|
|||
use crate::GenerateRequest;
|
||||
use crate::{Entry, Queue, Token};
|
||||
use nohash_hasher::IntMap;
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{
|
||||
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
|
||||
|
@ -81,6 +80,7 @@ impl Infer {
|
|||
.limit_concurrent_requests
|
||||
.try_acquire_owned()
|
||||
.map_err(|err| {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
})?;
|
||||
|
@ -172,6 +172,7 @@ impl Infer {
|
|||
})
|
||||
} else {
|
||||
let err = InferError::IncompleteGeneration;
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
|
||||
tracing::error!("{err}");
|
||||
Err(err)
|
||||
}
|
||||
|
@ -201,7 +202,7 @@ async fn batching_task(
|
|||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||
// waiting in the queue
|
||||
while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
|
||||
let mut cached_batch = wrap_future(client.prefill(batch), &mut entries)
|
||||
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
let mut waiting_tokens = 1;
|
||||
|
@ -212,6 +213,7 @@ async fn batching_task(
|
|||
// Get current batch info
|
||||
let batch_size = batch.size;
|
||||
let mut batches = vec![batch];
|
||||
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
|
||||
|
||||
// If the current batch is too small, we try to add more requests to it
|
||||
if batch_size <= limit_min_batch_size {
|
||||
|
@ -241,10 +243,9 @@ async fn batching_task(
|
|||
});
|
||||
|
||||
// Generate one token for this new batch to have the attention past in cache
|
||||
let new_cached_batch =
|
||||
wrap_future(client.prefill(new_batch), &mut new_entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
// Reset waiting counter
|
||||
waiting_tokens = 1;
|
||||
// Extend current batch with the new batch
|
||||
|
@ -268,29 +269,59 @@ async fn batching_task(
|
|||
entry.temp_span = Some(entry_batch_span);
|
||||
});
|
||||
|
||||
cached_batch = wrap_future(client.decode(batches), &mut entries)
|
||||
cached_batch = decode(&mut client, batches, &mut entries)
|
||||
.instrument(next_batch_span)
|
||||
.await;
|
||||
waiting_tokens += 1;
|
||||
}
|
||||
metrics::gauge!("tgi_batch_current_size", 0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
|
||||
#[instrument(skip_all)]
|
||||
async fn wrap_future(
|
||||
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
|
||||
async fn prefill(
|
||||
client: &mut ShardedClient,
|
||||
batch: Batch,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<Batch> {
|
||||
match future.await {
|
||||
let start_time = Instant::now();
|
||||
|
||||
match client.prefill(batch).await {
|
||||
Ok((generations, next_batch)) => {
|
||||
send_generations(generations, entries);
|
||||
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "prefill");
|
||||
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
send_errors(err, entries);
|
||||
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn decode(
|
||||
client: &mut ShardedClient,
|
||||
batches: Vec<Batch>,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<Batch> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
match client.decode(batches).await {
|
||||
Ok((generations, next_batch)) => {
|
||||
send_generations(generations, entries);
|
||||
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "decode");
|
||||
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
send_errors(err, entries);
|
||||
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
|
||||
None
|
||||
}
|
||||
}
|
||||
|
@ -303,6 +334,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
|||
// Create and enter a span to link this function back to the entry
|
||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||
let err = InferError::GenerationError(error.to_string());
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
|
||||
tracing::error!("{err}");
|
||||
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
|
|
|
@ -132,6 +132,7 @@ impl State {
|
|||
// Push entry in the queue
|
||||
self.entries.push((self.next_id, entry));
|
||||
self.next_id += 1;
|
||||
metrics::increment_gauge!("tgi_queue_size", 1.0);
|
||||
}
|
||||
|
||||
// Get the next batch
|
||||
|
@ -190,6 +191,8 @@ impl State {
|
|||
// Increment batch id
|
||||
self.next_batch_id += 1;
|
||||
|
||||
metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
|
||||
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
|
||||
Some((batch_entries, batch, next_batch_span))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ use axum::routing::{get, post};
|
|||
use axum::{Json, Router};
|
||||
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
||||
use futures::Stream;
|
||||
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use text_generation_client::ShardedClient;
|
||||
|
@ -57,14 +58,14 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
|
|||
path = "/generate",
|
||||
request_body = GenerateRequest,
|
||||
responses(
|
||||
(status = 200, description = "Generated Text", body = [GenerateResponse]),
|
||||
(status = 424, description = "Generation Error", body = [ErrorResponse],
|
||||
(status = 200, description = "Generated Text", body = GenerateResponse),
|
||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||
example = json!({"error": "Request failed during generation"})),
|
||||
(status = 429, description = "Model is overloaded", body = [ErrorResponse],
|
||||
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||
example = json!({"error": "Model is overloaded"})),
|
||||
(status = 422, description = "Input validation error", body = [ErrorResponse],
|
||||
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||
example = json!({"error": "Input validation error"})),
|
||||
(status = 500, description = "Incomplete generation", body = [ErrorResponse],
|
||||
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||
example = json!({"error": "Incomplete generation"})),
|
||||
)
|
||||
)]
|
||||
|
@ -141,6 +142,18 @@ async fn generate(
|
|||
span.record("seed", format!("{:?}", response.generated_text.seed));
|
||||
tracing::info!("Output: {}", response.generated_text.text);
|
||||
|
||||
// Metrics
|
||||
metrics::increment_counter!("tgi_request_success");
|
||||
metrics::histogram!("tgi_request_duration", total_time);
|
||||
metrics::histogram!("tgi_request_validation_duration", validation_time);
|
||||
metrics::histogram!("tgi_request_queue_duration", queue_time);
|
||||
metrics::histogram!("tgi_request_inference_duration", inference_time);
|
||||
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
|
||||
metrics::histogram!(
|
||||
"tgi_request_generated_tokens",
|
||||
response.generated_text.generated_tokens as f64
|
||||
);
|
||||
|
||||
// Send response
|
||||
let response = GenerateResponse {
|
||||
generated_text: response.generated_text.text,
|
||||
|
@ -156,20 +169,20 @@ async fn generate(
|
|||
path = "/generate_stream",
|
||||
request_body = GenerateRequest,
|
||||
responses(
|
||||
(status = 200, description = "Generated Text", body = [StreamResponse],
|
||||
content_type="text/event-stream "),
|
||||
(status = 424, description = "Generation Error", body = [ErrorResponse],
|
||||
(status = 200, description = "Generated Text", body = StreamResponse,
|
||||
content_type="text/event-stream"),
|
||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||
example = json!({"error": "Request failed during generation"}),
|
||||
content_type="text/event-stream "),
|
||||
(status = 429, description = "Model is overloaded", body = [ErrorResponse],
|
||||
content_type="text/event-stream"),
|
||||
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||
example = json!({"error": "Model is overloaded"}),
|
||||
content_type="text/event-stream "),
|
||||
(status = 422, description = "Input validation error", body = [ErrorResponse],
|
||||
content_type="text/event-stream"),
|
||||
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||
example = json!({"error": "Input validation error"}),
|
||||
content_type="text/event-stream "),
|
||||
(status = 500, description = "Incomplete generation", body = [ErrorResponse],
|
||||
content_type="text/event-stream"),
|
||||
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||
example = json!({"error": "Incomplete generation"}),
|
||||
content_type="text/event-stream "),
|
||||
content_type="text/event-stream"),
|
||||
)
|
||||
)]
|
||||
#[instrument(
|
||||
|
@ -249,6 +262,15 @@ async fn generate_stream(
|
|||
span.record("seed", format!("{:?}", generated_text.seed));
|
||||
tracing::info!(parent: &span, "Output: {}", generated_text.text);
|
||||
|
||||
// Metrics
|
||||
metrics::increment_counter!("tgi_request_success");
|
||||
metrics::histogram!("tgi_request_duration", total_time);
|
||||
metrics::histogram!("tgi_request_validation_duration", validation_time);
|
||||
metrics::histogram!("tgi_request_queue_duration", queue_time);
|
||||
metrics::histogram!("tgi_request_inference_duration", inference_time);
|
||||
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
|
||||
metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);
|
||||
|
||||
// StreamResponse
|
||||
end_reached = true;
|
||||
let stream_token = StreamResponse {
|
||||
|
@ -279,6 +301,7 @@ async fn generate_stream(
|
|||
// Skip if we already sent an error
|
||||
if !end_reached && !error {
|
||||
let err = InferError::IncompleteGeneration;
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
|
||||
tracing::error!("{err}");
|
||||
yield Ok(Event::from(err))
|
||||
}
|
||||
|
@ -287,6 +310,17 @@ async fn generate_stream(
|
|||
Sse::new(stream).keep_alive(KeepAlive::default())
|
||||
}
|
||||
|
||||
/// Prometheus metrics scrape endpoint
|
||||
#[utoipa::path(
|
||||
get,
|
||||
tag = "Text Generation Inference",
|
||||
path = "/metrics",
|
||||
responses((status = 200, description = "Prometheus Metrics", body = String))
|
||||
)]
|
||||
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
|
||||
prom_handle.render()
|
||||
}
|
||||
|
||||
/// Serving method
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn run(
|
||||
|
@ -307,6 +341,7 @@ pub async fn run(
|
|||
paths(
|
||||
generate,
|
||||
generate_stream,
|
||||
metrics,
|
||||
),
|
||||
components(
|
||||
schemas(
|
||||
|
@ -350,6 +385,12 @@ pub async fn run(
|
|||
max_concurrent_requests,
|
||||
);
|
||||
|
||||
// Prometheus handler
|
||||
let builder = PrometheusBuilder::new();
|
||||
let prom_handle = builder
|
||||
.install_recorder()
|
||||
.expect("failed to install metrics recorder");
|
||||
|
||||
// Create router
|
||||
let app = Router::new()
|
||||
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
||||
|
@ -359,6 +400,8 @@ pub async fn run(
|
|||
.route("/", get(health))
|
||||
.route("/health", get(health))
|
||||
.layer(Extension(infer))
|
||||
.route("/metrics", get(metrics))
|
||||
.layer(Extension(prom_handle))
|
||||
.layer(opentelemetry_tracing_layer());
|
||||
|
||||
// Run server
|
||||
|
|
|
@ -13,7 +13,7 @@ use tracing::{instrument, Span};
|
|||
#[derive(Debug, Clone)]
|
||||
pub struct Validation {
|
||||
/// Channel to communicate with the background validation task
|
||||
sender: mpsc::Sender<ValidationRequest>,
|
||||
sender: mpsc::UnboundedSender<ValidationRequest>,
|
||||
}
|
||||
|
||||
impl Validation {
|
||||
|
@ -25,7 +25,7 @@ impl Validation {
|
|||
max_total_tokens: usize,
|
||||
) -> Self {
|
||||
// Create channel
|
||||
let (validation_sender, validation_receiver) = mpsc::channel(128);
|
||||
let (validation_sender, validation_receiver) = mpsc::unbounded_channel();
|
||||
|
||||
// Launch background validation task
|
||||
tokio::spawn(validation_task(
|
||||
|
@ -54,7 +54,6 @@ impl Validation {
|
|||
// Unwrap is safe here
|
||||
self.sender
|
||||
.send((request, sender, Span::current()))
|
||||
.await
|
||||
.unwrap();
|
||||
// Await on response channel
|
||||
// Unwrap is safe here
|
||||
|
@ -70,7 +69,7 @@ async fn validation_task(
|
|||
max_stop_sequences: usize,
|
||||
max_input_length: usize,
|
||||
max_total_tokens: usize,
|
||||
mut receiver: mpsc::Receiver<ValidationRequest>,
|
||||
mut receiver: mpsc::UnboundedReceiver<ValidationRequest>,
|
||||
) {
|
||||
let mut workers_senders = Vec::with_capacity(workers);
|
||||
|
||||
|
@ -131,6 +130,7 @@ fn validation_worker(
|
|||
&mut rng,
|
||||
)
|
||||
.map_err(|err| {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
}),
|
||||
|
@ -214,6 +214,7 @@ fn validate(
|
|||
Ok(encoding) => {
|
||||
let input_length = encoding.len();
|
||||
let total_tokens = input_length + max_new_tokens as usize;
|
||||
|
||||
if input_length > max_input_length {
|
||||
Err(ValidationError::InputLength(max_input_length, input_length))
|
||||
} else if total_tokens > max_total_tokens {
|
||||
|
@ -237,6 +238,9 @@ fn validate(
|
|||
stop_sequences,
|
||||
};
|
||||
|
||||
metrics::histogram!("tgi_request_input_length", input_length as f64);
|
||||
metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
|
||||
|
||||
Ok(ValidGenerateRequest {
|
||||
inputs: request.inputs,
|
||||
input_length: input_length as u32,
|
||||
|
|
|
@ -49,6 +49,8 @@ def convert_file(pt_file: Path, st_file: Path):
|
|||
"""
|
||||
Convert a pytorch file to a safetensors file
|
||||
"""
|
||||
logger.info(f"Convert {pt_file} to {st_file}.")
|
||||
|
||||
pt_state = torch.load(pt_file, map_location="cpu")
|
||||
if "state_dict" in pt_state:
|
||||
pt_state = pt_state["state_dict"]
|
||||
|
|
|
@ -132,9 +132,9 @@ def download_weights(
|
|||
local_file = try_to_load_from_cache(model_id, revision, filename)
|
||||
if local_file is not None:
|
||||
logger.info(f"File {filename} already present in cache.")
|
||||
return local_file
|
||||
return Path(local_file)
|
||||
|
||||
logger.info(f"Starting {filename} download.")
|
||||
logger.info(f"Download file: {filename}")
|
||||
start_time = time.time()
|
||||
local_file = hf_hub_download(
|
||||
filename=filename,
|
||||
|
@ -143,7 +143,7 @@ def download_weights(
|
|||
local_files_only=False,
|
||||
)
|
||||
logger.info(
|
||||
f"Downloaded {filename} at {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
|
||||
f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
|
||||
)
|
||||
return Path(local_file)
|
||||
|
||||
|
|
Loading…
Reference in New Issue