feat(router): add prometheus metrics scrape endpoint (#71)

This commit is contained in:
OlivierDehaene 2023-02-16 17:18:53 +01:00 committed by GitHub
parent 7b3d460d21
commit 439fcaf810
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 239 additions and 33 deletions

120
Cargo.lock generated
View File

@ -8,6 +8,17 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "ahash"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
dependencies = [
"getrandom",
"once_cell",
"version_check",
]
[[package]] [[package]]
name = "aho-corasick" name = "aho-corasick"
version = "0.7.20" version = "0.7.20"
@ -806,6 +817,9 @@ name = "hashbrown"
version = "0.12.3" version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
dependencies = [
"ahash",
]
[[package]] [[package]]
name = "heck" name = "heck"
@ -1093,6 +1107,15 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "mach"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b823e83b2affd8f40a9ee8c29dbc56404c1e34cd2710921f2801e2cf29527afa"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "macro_rules_attribute" name = "macro_rules_attribute"
version = "0.1.3" version = "0.1.3"
@ -1139,6 +1162,64 @@ dependencies = [
"autocfg", "autocfg",
] ]
[[package]]
name = "metrics"
version = "0.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b9b8653cec6897f73b519a43fba5ee3d50f62fe9af80b428accdcc093b4a849"
dependencies = [
"ahash",
"metrics-macros",
"portable-atomic",
]
[[package]]
name = "metrics-exporter-prometheus"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8603921e1f54ef386189335f288441af761e0fc61bcb552168d9cedfe63ebc70"
dependencies = [
"hyper",
"indexmap",
"ipnet",
"metrics",
"metrics-util",
"parking_lot",
"portable-atomic",
"quanta",
"thiserror",
"tokio",
"tracing",
]
[[package]]
name = "metrics-macros"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "731f8ecebd9f3a4aa847dfe75455e4757a45da40a7793d2f0b1f9b6ed18b23f3"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "metrics-util"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7d24dc2dbae22bff6f1f9326ffce828c9f07ef9cc1e8002e5279f845432a30a"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
"hashbrown",
"metrics",
"num_cpus",
"parking_lot",
"portable-atomic",
"quanta",
"sketches-ddsketch",
]
[[package]] [[package]]
name = "mime" name = "mime"
version = "0.3.16" version = "0.3.16"
@ -1514,6 +1595,12 @@ version = "0.3.26"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160" checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
[[package]]
name = "portable-atomic"
version = "0.3.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26f6a7b87c2e435a3241addceeeff740ff8b7e76b74c13bf9acb17fa454ea00b"
[[package]] [[package]]
name = "ppv-lite86" name = "ppv-lite86"
version = "0.2.17" version = "0.2.17"
@ -1618,6 +1705,22 @@ dependencies = [
"prost", "prost",
] ]
[[package]]
name = "quanta"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7e31331286705f455e56cca62e0e717158474ff02b7936c1fa596d983f4ae27"
dependencies = [
"crossbeam-utils",
"libc",
"mach",
"once_cell",
"raw-cpuid",
"wasi 0.10.0+wasi-snapshot-preview1",
"web-sys",
"winapi",
]
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.23" version = "1.0.23"
@ -1657,6 +1760,15 @@ dependencies = [
"getrandom", "getrandom",
] ]
[[package]]
name = "raw-cpuid"
version = "10.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c307f7aacdbab3f0adee67d52739a1d71112cc068d6fab169ddeb18e48877fad"
dependencies = [
"bitflags",
]
[[package]] [[package]]
name = "rayon" name = "rayon"
version = "1.6.1" version = "1.6.1"
@ -1980,6 +2092,12 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "sketches-ddsketch"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ceb945e54128e09c43d8e4f1277851bd5044c6fc540bbaa2ad888f60b3da9ae7"
[[package]] [[package]]
name = "slab" name = "slab"
version = "0.4.7" version = "0.4.7"
@ -2143,6 +2261,8 @@ dependencies = [
"axum-tracing-opentelemetry", "axum-tracing-opentelemetry",
"clap 4.1.4", "clap 4.1.4",
"futures", "futures",
"metrics",
"metrics-exporter-prometheus",
"nohash-hasher", "nohash-hasher",
"opentelemetry", "opentelemetry",
"opentelemetry-otlp", "opentelemetry-otlp",

View File

@ -19,6 +19,8 @@ axum-tracing-opentelemetry = "0.9.0"
text-generation-client = { path = "client" } text-generation-client = { path = "client" }
clap = { version = "4.1.4", features = ["derive", "env"] } clap = { version = "4.1.4", features = ["derive", "env"] }
futures = "0.3.26" futures = "0.3.26"
metrics = "0.20.1"
metrics-exporter-prometheus = { version = "0.11.0", features = [] }
nohash-hasher = "0.2.0" nohash-hasher = "0.2.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] } opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0" opentelemetry-otlp = "0.11.0"

View File

@ -3,7 +3,6 @@ use crate::validation::{Validation, ValidationError};
use crate::GenerateRequest; use crate::GenerateRequest;
use crate::{Entry, Queue, Token}; use crate::{Entry, Queue, Token};
use nohash_hasher::IntMap; use nohash_hasher::IntMap;
use std::future::Future;
use std::sync::Arc; use std::sync::Arc;
use text_generation_client::{ use text_generation_client::{
Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient, Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
@ -81,6 +80,7 @@ impl Infer {
.limit_concurrent_requests .limit_concurrent_requests
.try_acquire_owned() .try_acquire_owned()
.map_err(|err| { .map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
tracing::error!("{err}"); tracing::error!("{err}");
err err
})?; })?;
@ -172,6 +172,7 @@ impl Infer {
}) })
} else { } else {
let err = InferError::IncompleteGeneration; let err = InferError::IncompleteGeneration;
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
tracing::error!("{err}"); tracing::error!("{err}");
Err(err) Err(err)
} }
@ -201,7 +202,7 @@ async fn batching_task(
// This batch might be smaller than the maximum batch size if there are not enough requests // This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue // waiting in the queue
while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await { while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
let mut cached_batch = wrap_future(client.prefill(batch), &mut entries) let mut cached_batch = prefill(&mut client, batch, &mut entries)
.instrument(span) .instrument(span)
.await; .await;
let mut waiting_tokens = 1; let mut waiting_tokens = 1;
@ -212,6 +213,7 @@ async fn batching_task(
// Get current batch info // Get current batch info
let batch_size = batch.size; let batch_size = batch.size;
let mut batches = vec![batch]; let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
// If the current batch is too small, we try to add more requests to it // If the current batch is too small, we try to add more requests to it
if batch_size <= limit_min_batch_size { if batch_size <= limit_min_batch_size {
@ -241,8 +243,7 @@ async fn batching_task(
}); });
// Generate one token for this new batch to have the attention past in cache // Generate one token for this new batch to have the attention past in cache
let new_cached_batch = let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
wrap_future(client.prefill(new_batch), &mut new_entries)
.instrument(span) .instrument(span)
.await; .await;
// Reset waiting counter // Reset waiting counter
@ -268,29 +269,59 @@ async fn batching_task(
entry.temp_span = Some(entry_batch_span); entry.temp_span = Some(entry_batch_span);
}); });
cached_batch = wrap_future(client.decode(batches), &mut entries) cached_batch = decode(&mut client, batches, &mut entries)
.instrument(next_batch_span) .instrument(next_batch_span)
.await; .await;
waiting_tokens += 1; waiting_tokens += 1;
} }
metrics::gauge!("tgi_batch_current_size", 0.0);
} }
} }
} }
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
#[instrument(skip_all)] #[instrument(skip_all)]
async fn wrap_future( async fn prefill(
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>, client: &mut ShardedClient,
batch: Batch,
entries: &mut IntMap<u64, Entry>, entries: &mut IntMap<u64, Entry>,
) -> Option<Batch> { ) -> Option<Batch> {
match future.await { let start_time = Instant::now();
match client.prefill(batch).await {
Ok((generations, next_batch)) => { Ok((generations, next_batch)) => {
send_generations(generations, entries); send_generations(generations, entries);
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "prefill");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
next_batch next_batch
} }
// If we have an error, we discard the whole batch // If we have an error, we discard the whole batch
Err(err) => { Err(err) => {
send_errors(err, entries); send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
None
}
}
}
#[instrument(skip_all)]
async fn decode(
client: &mut ShardedClient,
batches: Vec<Batch>,
entries: &mut IntMap<u64, Entry>,
) -> Option<Batch> {
let start_time = Instant::now();
match client.decode(batches).await {
Ok((generations, next_batch)) => {
send_generations(generations, entries);
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "decode");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
None None
} }
} }
@ -303,6 +334,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
// Create and enter a span to link this function back to the entry // Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered(); let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string()); let err = InferError::GenerationError(error.to_string());
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
tracing::error!("{err}"); tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone. // unwrap_or is valid here as we don't care if the receiver is gone.

View File

@ -132,6 +132,7 @@ impl State {
// Push entry in the queue // Push entry in the queue
self.entries.push((self.next_id, entry)); self.entries.push((self.next_id, entry));
self.next_id += 1; self.next_id += 1;
metrics::increment_gauge!("tgi_queue_size", 1.0);
} }
// Get the next batch // Get the next batch
@ -190,6 +191,8 @@ impl State {
// Increment batch id // Increment batch id
self.next_batch_id += 1; self.next_batch_id += 1;
metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
Some((batch_entries, batch, next_batch_span)) Some((batch_entries, batch, next_batch_span))
} }
} }

View File

@ -12,6 +12,7 @@ use axum::routing::{get, post};
use axum::{Json, Router}; use axum::{Json, Router};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer; use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use futures::Stream; use futures::Stream;
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use std::convert::Infallible; use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
use text_generation_client::ShardedClient; use text_generation_client::ShardedClient;
@ -57,14 +58,14 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
path = "/generate", path = "/generate",
request_body = GenerateRequest, request_body = GenerateRequest,
responses( responses(
(status = 200, description = "Generated Text", body = [GenerateResponse]), (status = 200, description = "Generated Text", body = GenerateResponse),
(status = 424, description = "Generation Error", body = [ErrorResponse], (status = 424, description = "Generation Error", body = ErrorResponse,
example = json!({"error": "Request failed during generation"})), example = json!({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = [ErrorResponse], (status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json!({"error": "Model is overloaded"})), example = json!({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = [ErrorResponse], (status = 422, description = "Input validation error", body = ErrorResponse,
example = json!({"error": "Input validation error"})), example = json!({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = [ErrorResponse], (status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json!({"error": "Incomplete generation"})), example = json!({"error": "Incomplete generation"})),
) )
)] )]
@ -141,6 +142,18 @@ async fn generate(
span.record("seed", format!("{:?}", response.generated_text.seed)); span.record("seed", format!("{:?}", response.generated_text.seed));
tracing::info!("Output: {}", response.generated_text.text); tracing::info!("Output: {}", response.generated_text.text);
// Metrics
metrics::increment_counter!("tgi_request_success");
metrics::histogram!("tgi_request_duration", total_time);
metrics::histogram!("tgi_request_validation_duration", validation_time);
metrics::histogram!("tgi_request_queue_duration", queue_time);
metrics::histogram!("tgi_request_inference_duration", inference_time);
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
metrics::histogram!(
"tgi_request_generated_tokens",
response.generated_text.generated_tokens as f64
);
// Send response // Send response
let response = GenerateResponse { let response = GenerateResponse {
generated_text: response.generated_text.text, generated_text: response.generated_text.text,
@ -156,20 +169,20 @@ async fn generate(
path = "/generate_stream", path = "/generate_stream",
request_body = GenerateRequest, request_body = GenerateRequest,
responses( responses(
(status = 200, description = "Generated Text", body = [StreamResponse], (status = 200, description = "Generated Text", body = StreamResponse,
content_type="text/event-stream "), content_type="text/event-stream"),
(status = 424, description = "Generation Error", body = [ErrorResponse], (status = 424, description = "Generation Error", body = ErrorResponse,
example = json!({"error": "Request failed during generation"}), example = json!({"error": "Request failed during generation"}),
content_type="text/event-stream "), content_type="text/event-stream"),
(status = 429, description = "Model is overloaded", body = [ErrorResponse], (status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json!({"error": "Model is overloaded"}), example = json!({"error": "Model is overloaded"}),
content_type="text/event-stream "), content_type="text/event-stream"),
(status = 422, description = "Input validation error", body = [ErrorResponse], (status = 422, description = "Input validation error", body = ErrorResponse,
example = json!({"error": "Input validation error"}), example = json!({"error": "Input validation error"}),
content_type="text/event-stream "), content_type="text/event-stream"),
(status = 500, description = "Incomplete generation", body = [ErrorResponse], (status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json!({"error": "Incomplete generation"}), example = json!({"error": "Incomplete generation"}),
content_type="text/event-stream "), content_type="text/event-stream"),
) )
)] )]
#[instrument( #[instrument(
@ -249,6 +262,15 @@ async fn generate_stream(
span.record("seed", format!("{:?}", generated_text.seed)); span.record("seed", format!("{:?}", generated_text.seed));
tracing::info!(parent: &span, "Output: {}", generated_text.text); tracing::info!(parent: &span, "Output: {}", generated_text.text);
// Metrics
metrics::increment_counter!("tgi_request_success");
metrics::histogram!("tgi_request_duration", total_time);
metrics::histogram!("tgi_request_validation_duration", validation_time);
metrics::histogram!("tgi_request_queue_duration", queue_time);
metrics::histogram!("tgi_request_inference_duration", inference_time);
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);
// StreamResponse // StreamResponse
end_reached = true; end_reached = true;
let stream_token = StreamResponse { let stream_token = StreamResponse {
@ -279,6 +301,7 @@ async fn generate_stream(
// Skip if we already sent an error // Skip if we already sent an error
if !end_reached && !error { if !end_reached && !error {
let err = InferError::IncompleteGeneration; let err = InferError::IncompleteGeneration;
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
tracing::error!("{err}"); tracing::error!("{err}");
yield Ok(Event::from(err)) yield Ok(Event::from(err))
} }
@ -287,6 +310,17 @@ async fn generate_stream(
Sse::new(stream).keep_alive(KeepAlive::default()) Sse::new(stream).keep_alive(KeepAlive::default())
} }
/// Prometheus metrics scrape endpoint
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/metrics",
responses((status = 200, description = "Prometheus Metrics", body = String))
)]
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
prom_handle.render()
}
/// Serving method /// Serving method
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn run( pub async fn run(
@ -307,6 +341,7 @@ pub async fn run(
paths( paths(
generate, generate,
generate_stream, generate_stream,
metrics,
), ),
components( components(
schemas( schemas(
@ -350,6 +385,12 @@ pub async fn run(
max_concurrent_requests, max_concurrent_requests,
); );
// Prometheus handler
let builder = PrometheusBuilder::new();
let prom_handle = builder
.install_recorder()
.expect("failed to install metrics recorder");
// Create router // Create router
let app = Router::new() let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
@ -359,6 +400,8 @@ pub async fn run(
.route("/", get(health)) .route("/", get(health))
.route("/health", get(health)) .route("/health", get(health))
.layer(Extension(infer)) .layer(Extension(infer))
.route("/metrics", get(metrics))
.layer(Extension(prom_handle))
.layer(opentelemetry_tracing_layer()); .layer(opentelemetry_tracing_layer());
// Run server // Run server

View File

@ -13,7 +13,7 @@ use tracing::{instrument, Span};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Validation { pub struct Validation {
/// Channel to communicate with the background validation task /// Channel to communicate with the background validation task
sender: mpsc::Sender<ValidationRequest>, sender: mpsc::UnboundedSender<ValidationRequest>,
} }
impl Validation { impl Validation {
@ -25,7 +25,7 @@ impl Validation {
max_total_tokens: usize, max_total_tokens: usize,
) -> Self { ) -> Self {
// Create channel // Create channel
let (validation_sender, validation_receiver) = mpsc::channel(128); let (validation_sender, validation_receiver) = mpsc::unbounded_channel();
// Launch background validation task // Launch background validation task
tokio::spawn(validation_task( tokio::spawn(validation_task(
@ -54,7 +54,6 @@ impl Validation {
// Unwrap is safe here // Unwrap is safe here
self.sender self.sender
.send((request, sender, Span::current())) .send((request, sender, Span::current()))
.await
.unwrap(); .unwrap();
// Await on response channel // Await on response channel
// Unwrap is safe here // Unwrap is safe here
@ -70,7 +69,7 @@ async fn validation_task(
max_stop_sequences: usize, max_stop_sequences: usize,
max_input_length: usize, max_input_length: usize,
max_total_tokens: usize, max_total_tokens: usize,
mut receiver: mpsc::Receiver<ValidationRequest>, mut receiver: mpsc::UnboundedReceiver<ValidationRequest>,
) { ) {
let mut workers_senders = Vec::with_capacity(workers); let mut workers_senders = Vec::with_capacity(workers);
@ -131,6 +130,7 @@ fn validation_worker(
&mut rng, &mut rng,
) )
.map_err(|err| { .map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
tracing::error!("{err}"); tracing::error!("{err}");
err err
}), }),
@ -214,6 +214,7 @@ fn validate(
Ok(encoding) => { Ok(encoding) => {
let input_length = encoding.len(); let input_length = encoding.len();
let total_tokens = input_length + max_new_tokens as usize; let total_tokens = input_length + max_new_tokens as usize;
if input_length > max_input_length { if input_length > max_input_length {
Err(ValidationError::InputLength(max_input_length, input_length)) Err(ValidationError::InputLength(max_input_length, input_length))
} else if total_tokens > max_total_tokens { } else if total_tokens > max_total_tokens {
@ -237,6 +238,9 @@ fn validate(
stop_sequences, stop_sequences,
}; };
metrics::histogram!("tgi_request_input_length", input_length as f64);
metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
Ok(ValidGenerateRequest { Ok(ValidGenerateRequest {
inputs: request.inputs, inputs: request.inputs,
input_length: input_length as u32, input_length: input_length as u32,

View File

@ -49,6 +49,8 @@ def convert_file(pt_file: Path, st_file: Path):
""" """
Convert a pytorch file to a safetensors file Convert a pytorch file to a safetensors file
""" """
logger.info(f"Convert {pt_file} to {st_file}.")
pt_state = torch.load(pt_file, map_location="cpu") pt_state = torch.load(pt_file, map_location="cpu")
if "state_dict" in pt_state: if "state_dict" in pt_state:
pt_state = pt_state["state_dict"] pt_state = pt_state["state_dict"]

View File

@ -132,9 +132,9 @@ def download_weights(
local_file = try_to_load_from_cache(model_id, revision, filename) local_file = try_to_load_from_cache(model_id, revision, filename)
if local_file is not None: if local_file is not None:
logger.info(f"File {filename} already present in cache.") logger.info(f"File {filename} already present in cache.")
return local_file return Path(local_file)
logger.info(f"Starting {filename} download.") logger.info(f"Download file: {filename}")
start_time = time.time() start_time = time.time()
local_file = hf_hub_download( local_file = hf_hub_download(
filename=filename, filename=filename,
@ -143,7 +143,7 @@ def download_weights(
local_files_only=False, local_files_only=False,
) )
logger.info( logger.info(
f"Downloaded {filename} at {local_file} in {timedelta(seconds=int(time.time() - start_time))}." f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
) )
return Path(local_file) return Path(local_file)