From 954653466d24a9b3435988136983398bdf788a2f Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 23 May 2024 15:40:40 +0200
Subject: [PATCH] Improving the logging system. (#1938)

- Added a debug log for speculated ids (helps assessing the quality of a speculator from the logs).
- Removed newlines from child-process logs when re-emitting them in non-JSON mode.
- Made the standard log level closer to what's expected (it now applies only to our own binaries, not their dependencies).
- Propagated that level correctly to the shard (it was previously forced to INFO).
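
For reviewers, a minimal sketch of the intended `LOG_LEVEL` behaviour (assuming `tracing-subscriber` with the `env-filter` feature enabled; `build_env_filter` is an illustrative name, not part of this patch):

```rust
use tracing_subscriber::{filter::LevelFilter, EnvFilter};

// Hedged sketch of the LOG_LEVEL handling now shared by launcher and router:
// bare levels are rewritten into per-crate directives so dependency events
// (tokio, hyper, ...) do not flood the output, while any other value is
// handed to EnvFilter verbatim as a directive string.
fn build_env_filter() -> EnvFilter {
    match std::env::var("LOG_LEVEL") {
        Ok(level) => {
            let directives = match level.as_str() {
                "warn" => "text_generation_launcher=warn,text_generation_router=warn",
                "info" => "text_generation_launcher=info,text_generation_router=info",
                "debug" => "text_generation_launcher=debug,text_generation_router=debug",
                other => other,
            };
            EnvFilter::builder()
                .with_default_directive(LevelFilter::INFO.into())
                .parse_lossy(directives)
        }
        Err(_) => EnvFilter::new("info"),
    }
}

fn main() {
    // The launcher derives the shard's level from the filter's hint.
    // LevelFilter displays lowercase ("debug"), hence the uppercasing before
    // it is forwarded to the Python server as --logger-level.
    let env_filter = build_env_filter();
    let max_log_level = env_filter.max_level_hint().unwrap_or(LevelFilter::INFO);
    println!("--logger-level {}", max_log_level.to_string().to_uppercase());
}
```

With `LOG_LEVEL=debug` the shard is now started with `--logger-level DEBUG`, and any value other than the three bare levels (e.g. `LOG_LEVEL=text_generation_router=trace`) passes through to `EnvFilter` untouched.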

# What does this PR do?

Fixes # (issue)

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
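
As a before/after illustration of the newline fix (a hedged sketch; the shard line below is hypothetical):

```rust
// Hedged sketch of the trim_end() fix: lines read from a shard still carry
// their trailing '\n', and the fmt subscriber appends its own newline, so
// re-emitting them untrimmed produced blank lines in non-JSON mode.
fn main() {
    tracing_subscriber::fmt().init();

    // Hypothetical line captured from a shard's stdout.
    let text = "Server started at unix:///tmp/text-generation-server-0\n";
    tracing::info!("{}", text);            // before: leaves a blank line behind
    tracing::info!("{}", text.trim_end()); // after: a single clean log line
}
```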
---
 launcher/src/main.rs                          | 40 ++++++++++++++-----
 router/src/main.rs                            | 19 +++++++--
 .../models/flash_causal_lm.py                 |  5 +++
 3 files changed, 50 insertions(+), 14 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index d74fca64..a97a75c0 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -17,7 +17,7 @@ use std::thread::sleep;
 use std::time::{Duration, Instant};
 use std::{fs, io};
 use thiserror::Error;
-use tracing_subscriber::EnvFilter;
+use tracing_subscriber::{filter::LevelFilter, EnvFilter};
 
 mod env_runtime;
 
@@ -470,6 +470,7 @@ fn shard_manager(
     max_total_tokens: usize,
     max_batch_size: Option<usize>,
     otlp_endpoint: Option<String>,
+    log_level: LevelFilter,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<AtomicBool>,
     _shutdown_sender: mpsc::Sender<()>,
@@ -492,7 +493,7 @@ fn shard_manager(
         "--uds-path".to_string(),
         uds_path,
         "--logger-level".to_string(),
-        "INFO".to_string(),
+        log_level.to_string().to_uppercase(),
         "--json-output".to_string(),
     ];
 
@@ -770,13 +771,13 @@ struct PythonLogMessage {
 impl PythonLogMessage {
     fn trace(&self) {
         match self.record.level.name {
-            PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text),
-            PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text),
-            PythonLogLevelEnum::Info => tracing::info!("{}", self.text),
-            PythonLogLevelEnum::Success => tracing::info!("{}", self.text),
-            PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text),
-            PythonLogLevelEnum::Error => tracing::error!("{}", self.text),
-            PythonLogLevelEnum::Critical => tracing::error!("{}", self.text),
+            PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text.trim_end()),
+            PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text.trim_end()),
+            PythonLogLevelEnum::Info => tracing::info!("{}", self.text.trim_end()),
+            PythonLogLevelEnum::Success => tracing::info!("{}", self.text.trim_end()),
+            PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text.trim_end()),
+            PythonLogLevelEnum::Error => tracing::error!("{}", self.text.trim_end()),
+            PythonLogLevelEnum::Critical => tracing::error!("{}", self.text.trim_end()),
         }
     }
 }
@@ -996,6 +997,7 @@ fn spawn_shards(
     args: &Args,
     cuda_graphs: Vec<usize>,
     max_total_tokens: usize,
+    max_log_level: LevelFilter,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
     shutdown_sender: mpsc::Sender<()>,
@@ -1053,6 +1055,7 @@ fn spawn_shards(
                 max_total_tokens,
                 max_batch_size,
                 otlp_endpoint,
+                max_log_level,
                 status_sender,
                 shutdown,
                 shutdown_sender,
@@ -1283,8 +1286,22 @@ fn main() -> Result<(), LauncherError> {
     let args: Args = Args::parse();
 
     // Filter events with LOG_LEVEL
-    let env_filter =
-        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
+    let varname = "LOG_LEVEL";
+    let env_filter = if let Ok(log_level) = std::env::var(varname) {
+        // Override to avoid simple logs to be spammed with tokio level informations
+        let log_level = match &log_level[..] {
+            "warn" => "text_generation_launcher=warn,text_generation_router=warn",
+            "info" => "text_generation_launcher=info,text_generation_router=info",
+            "debug" => "text_generation_launcher=debug,text_generation_router=debug",
+            log_level => log_level,
+        };
+        EnvFilter::builder()
+            .with_default_directive(LevelFilter::INFO.into())
+            .parse_lossy(log_level)
+    } else {
+        EnvFilter::new("info")
+    };
+    let max_log_level = env_filter.max_level_hint().unwrap_or(LevelFilter::INFO);
 
     if args.json_output {
         tracing_subscriber::fmt()
@@ -1506,6 +1523,7 @@ fn main() -> Result<(), LauncherError> {
         &args,
         cuda_graphs,
         max_total_tokens,
+        max_log_level,
         shutdown.clone(),
         &shutdown_receiver,
         shutdown_sender,
diff --git a/router/src/main.rs b/router/src/main.rs
index 63347b78..b11c4526 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -20,7 +20,7 @@ use tokenizers::Tokenizer;
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
-use tracing_subscriber::{EnvFilter, Layer};
+use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
 
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -454,8 +454,21 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
     }
 
     // Filter events with LOG_LEVEL
-    let env_filter =
-        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
+    let varname = "LOG_LEVEL";
+    let env_filter = if let Ok(log_level) = std::env::var(varname) {
+        // Override to avoid simple logs to be spammed with tokio level informations
+        let log_level = match &log_level[..] {
+            "warn" => "text_generation_launcher=warn,text_generation_router=warn",
+            "info" => "text_generation_launcher=info,text_generation_router=info",
+            "debug" => "text_generation_launcher=debug,text_generation_router=debug",
+            log_level => log_level,
+        };
+        EnvFilter::builder()
+            .with_default_directive(LevelFilter::INFO.into())
+            .parse_lossy(log_level)
+    } else {
+        EnvFilter::new("info")
+    };
 
     tracing_subscriber::registry()
         .with(env_filter)
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 45ddd856..86d9b4c8 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -17,6 +17,7 @@ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
+from text_generation_server.utils.dist import RANK
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
     Batch,
@@ -1187,6 +1188,10 @@ class FlashCausalLM(Model):
         next_token_texts = []
         left = 0
 
+        if n_accepted_ids > 1:
+            if RANK == 0:
+                logger.debug(f"Speculated ids {n_accepted_ids - 1}")
+
         current_stopped = False
         for j in range(index, index + n_accepted_ids):
             # Generated token
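
One last test-style sketch (not part of the patch, and the asserted hint is an assumption about `EnvFilter`'s behaviour): values that are not one of the three bare levels must reach the filter unchanged, so richer per-target directives keep working.

```rust
use tracing_subscriber::{filter::LevelFilter, EnvFilter};

fn main() {
    // A LOG_LEVEL value other than bare "warn"/"info"/"debug" is passed
    // through verbatim, so per-target directives still behave as before.
    let filter = EnvFilter::builder()
        .with_default_directive(LevelFilter::INFO.into())
        .parse_lossy("text_generation_router=trace");
    // The hint reflects the most verbose directive; the launcher relies on
    // the same call to choose the shard's --logger-level.
    assert_eq!(filter.max_level_hint(), Some(LevelFilter::TRACE));
}
```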