diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index a2c2b7fb..26cbf3b2 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -67,7 +67,10 @@ jobs:
         run: |
           pip install pytest
           HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests
-      - name: Run Clippy
+      - name: Run Rust fmt
+        run: |
+          cargo fmt --check
+      - name: Run Rust clippy
         run: |
           cargo clippy
       - name: Run Rust tests
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index b59b0cb4..7b5f908a 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -493,6 +493,7 @@ fn download_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherE
     Ok(())
 }

+#[allow(clippy::too_many_arguments)]
 fn spawn_shards(
     num_shard: usize,
     args: &Args,
@@ -515,11 +516,11 @@ fn spawn_shards(
         let shutdown = shutdown.clone();
         let shutdown_sender = shutdown_sender.clone();
         let otlp_endpoint = args.otlp_endpoint.clone();
-        let quantize = args.quantize.clone();
-        let master_port = args.master_port.clone();
-        let disable_custom_kernels = args.disable_custom_kernels.clone();
-        let watermark_gamma = args.watermark_gamma.clone();
-        let watermark_delta = args.watermark_delta.clone();
+        let quantize = args.quantize;
+        let master_port = args.master_port;
+        let disable_custom_kernels = args.disable_custom_kernels;
+        let watermark_gamma = args.watermark_gamma;
+        let watermark_delta = args.watermark_delta;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -559,12 +560,12 @@ fn spawn_shards(
             }
             Ok(ShardStatus::Failed((rank, err))) => {
                 tracing::error!("Shard {} failed to start:\n{}", rank, err);
-                shutdown_shards(shutdown, &shutdown_receiver);
+                shutdown_shards(shutdown, shutdown_receiver);
                 return Err(LauncherError::ShardCannotStart);
             }
             Err(TryRecvError::Disconnected) => {
                 tracing::error!("Shard status channel disconnected");
-                shutdown_shards(shutdown, &shutdown_receiver);
+                shutdown_shards(shutdown, shutdown_receiver);
                 return Err(LauncherError::ShardDisconnected);
             }
         }
@@ -666,7 +667,7 @@ fn spawn_webserver(
                 tracing::error!("{}", err);
             }

-            shutdown_shards(shutdown, &shutdown_receiver);
+            shutdown_shards(shutdown, shutdown_receiver);
             return Err(LauncherError::WebserverCannotStart);
         }
     };
diff --git a/proto/generate.proto b/proto/generate.proto
index ad47409e..894d7bc1 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -15,8 +15,13 @@ service TextGenerationService {
     rpc Prefill (PrefillRequest) returns (PrefillResponse);
     /// Decode token for a list of prefilled batches
     rpc Decode (DecodeRequest) returns (DecodeResponse);
+    /// Health check
+    rpc Health (HealthRequest) returns (HealthResponse);
 }

+message HealthRequest {}
+message HealthResponse {}
+
 /// Empty request
 message InfoRequest {}

@@ -173,4 +178,4 @@ message DecodeResponse {
     repeated Generation generations = 1;
     /// Next batch (cached)
     optional Batch batch = 2;
-}
\ No newline at end of file
+}
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 7cadf430..bf1b6b58 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -7,7 +7,7 @@ use tonic::transport::{Channel, Uri};
 use tracing::instrument;

 /// Text Generation Inference gRPC client
-#[derive(Clone)]
+#[derive(Debug, Clone)]
 pub struct Client {
     stub: TextGenerationServiceClient<Channel>,
 }
@@ -62,6 +62,14 @@ impl Client {
         Ok(response)
     }

+    /// Get model health
+    #[instrument(skip(self))]
+    pub async fn health(&mut self) -> Result<HealthResponse> {
+        let request = tonic::Request::new(HealthRequest {}).inject_context();
+        let response = self.stub.health(request).await?.into_inner();
+        Ok(response)
+    }
+
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index 6a001306..401082c5 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -6,6 +6,7 @@ mod pb;
 mod sharded_client;

 pub use client::Client;
+pub use pb::generate::v1::HealthResponse;
 pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
     Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index 469d75f6..2f57a437 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -1,10 +1,11 @@
 /// Multi shard Client
 use crate::Result;
-use crate::{Batch, Client, Generation, Request, ShardInfo};
+use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;
+#[derive(Debug, Clone)]
 /// Text Generation Inference gRPC multi client
 pub struct ShardedClient {
     clients: Vec<Client>,
 }
@@ -48,6 +49,17 @@ impl ShardedClient {
         join_all(futures).await.pop().unwrap()
     }

+    /// GRPC health check
+    #[instrument(skip(self))]
+    pub async fn health(&mut self) -> Result<HealthResponse> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| client.health())
+            .collect();
+        join_all(futures).await.pop().unwrap()
+    }
+
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
diff --git a/router/src/health.rs b/router/src/health.rs
new file mode 100644
index 00000000..02edf328
--- /dev/null
+++ b/router/src/health.rs
@@ -0,0 +1,62 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+use text_generation_client::{
+    Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
+};
+
+#[derive(Clone, Debug)]
+pub(crate) struct Health {
+    client: ShardedClient,
+    generation_health: Arc<AtomicBool>,
+}
+
+impl Health {
+    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
+        Self {
+            client,
+            generation_health,
+        }
+    }
+
+    pub(crate) async fn check(&mut self) -> bool {
+        if self.generation_health.load(Ordering::SeqCst) {
+            // Generation is healthy, we only check that the shards are answering gRPC calls
+            self.client.health().await.is_ok()
+        } else {
+            // Generation is unhealthy or we have not sent any generation request yet
+
+            // Dummy batch of 1 token and 1 generated token
+            let liveness_request = Request {
+                id: u64::MAX,
+                inputs: "liveness".to_string(),
+                truncate: 10,
+                parameters: Some(NextTokenChooserParameters {
+                    temperature: 1.0,
+                    top_k: 0,
+                    top_p: 1.0,
+                    typical_p: 1.0,
+                    do_sample: false,
+                    seed: 0,
+                    repetition_penalty: 1.0,
+                    watermark: false,
+                }),
+                stopping_parameters: Some(StoppingCriteriaParameters {
+                    max_new_tokens: 1,
+                    stop_sequences: vec![],
+                    ignore_eos_token: false,
+                }),
+            };
+            let batch = Batch {
+                id: u64::MAX,
+                requests: vec![liveness_request],
+                size: 1,
+                max_tokens: 2,
+            };
+            // Skips the queue
+            let value = self.client.prefill(batch).await.is_ok();
+            // Update generation health
+            self.generation_health.store(value, Ordering::SeqCst);
+            value
+        }
+    }
+}
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 8b44ec86..313ec3e1 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -7,7 +7,10 @@ use flume::SendError;
 use futures::future::try_join_all;
 use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
-use std::sync::Arc;
+use std::sync::{
+    atomic::{AtomicBool, Ordering},
+    Arc,
+};
 use text_generation_client::{
     Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
@@ -36,6 +39,7 @@ struct Shared {
 }

 impl Infer {
+    #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         client: ShardedClient,
         validation: Validation,
@@ -44,6 +48,7 @@ impl Infer {
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
         requires_padding: bool,
+        generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(requires_padding);
@@ -59,6 +64,7 @@ impl Infer {
             max_waiting_tokens,
             queue.clone(),
             shared.clone(),
+            generation_health,
         ));

         // Inference limit with a semaphore
@@ -240,6 +246,7 @@ async fn batching_task(
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
+    generation_health: Arc<AtomicBool>,
 ) {
     // Infinite loop
     loop {
@@ -252,7 +259,7 @@ async fn batching_task(
         while let Some((mut entries, batch, span)) =
             queue.next_batch(None, max_batch_total_tokens).await
         {
-            let mut cached_batch = prefill(&mut client, batch, &mut entries)
+            let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
                 .instrument(span)
                 .await;
             let mut waiting_tokens = 1;
@@ -301,9 +308,10 @@ async fn batching_task(
                 });

                 // Generate one token for this new batch to have the attention past in cache
-                let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
-                    .instrument(span)
-                    .await;
+                let new_cached_batch =
+                    prefill(&mut client, new_batch, &mut new_entries, &generation_health)
+                        .instrument(span)
+                        .await;
                 // Reset waiting counter
                 waiting_tokens = 1;
                 // Extend current batch with the new batch
@@ -327,7 +335,7 @@ async fn batching_task(
                     entry.temp_span = Some(entry_batch_span);
                 });

-                cached_batch = decode(&mut client, batches, &mut entries)
+                cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
                     .instrument(next_batch_span)
                     .await;
                 waiting_tokens += 1;
@@ -343,6 +351,7 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
@@ -350,6 +359,8 @@ async fn prefill(

     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
+            // Update health
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);

@@ -362,6 +373,8 @@ async fn prefill(
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
+            // Update health
+            generation_health.store(false, Ordering::SeqCst);
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
@@ -375,6 +388,7 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<Batch>,
     entries: &mut IntMap<u64, Entry>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
@@ -382,6 +396,8 @@ async fn decode(

     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
+            // Update health
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);

@@ -394,6 +410,7 @@ async fn decode(
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
+            generation_health.store(false, Ordering::SeqCst);
             for id in batch_ids {
                 let _ = client.clear_cache(Some(id)).await;
             }
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 85b13cfa..c2ff669b 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1,3 +1,4 @@
+mod health;
 /// Text Generation Inference Webserver
 mod infer;
 mod queue;
@@ -278,17 +279,21 @@ pub(crate) struct ErrorResponse {
 }

 #[cfg(test)]
-mod tests{
+mod tests {
     use std::io::Write;
     use tokenizers::Tokenizer;

-    pub(crate) async fn get_tokenizer() -> Tokenizer{
-        if !std::path::Path::new("tokenizer.json").exists(){
-            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap();
-            let mut file = std::fs::File::create("tokenizer.json").unwrap();
+    pub(crate) async fn get_tokenizer() -> Tokenizer {
+        if !std::path::Path::new("tokenizer.json").exists() {
+            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
+                .await
+                .unwrap()
+                .bytes()
+                .await
+                .unwrap();
+            let mut file = std::fs::File::create("tokenizer.json").unwrap();
             file.write_all(&content).unwrap();
         }
         Tokenizer::from_file("tokenizer.json").unwrap()
     }
 }
-
diff --git a/router/src/queue.rs b/router/src/queue.rs
index d3f118d8..94851e1c 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -141,7 +141,6 @@ impl State {

     // Get the next batch
     fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
-
         if self.entries.is_empty() {
             return None;
         }
diff --git a/router/src/server.rs b/router/src/server.rs
index 09b5c3ba..f25fa2b3 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1,3 +1,4 @@
+use crate::health::Health;
 /// HTTP Server logic
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
@@ -18,6 +19,8 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
@@ -82,36 +85,29 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
     Json(info.0)
 }

+#[utoipa::path(
+get,
+tag = "Text Generation Inference",
+path = "/health",
+responses(
+(status = 200, description = "Everything is working fine"),
+(status = 503, description = "Text generation inference is down", body = ErrorResponse,
+example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
+)
+)]
+#[instrument(skip(health))]
 /// Health check method
-#[instrument(skip(infer))]
-async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
-    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
-    // be a bit too slow for a health check.
-    // What we should do instead is check if the gRPC channels are still healthy.
-
-    // Send a small inference request
-    infer
-        .generate(GenerateRequest {
-            inputs: "liveness".to_string(),
-            parameters: GenerateParameters {
-                best_of: None,
-                temperature: None,
-                repetition_penalty: None,
-                top_k: None,
-                top_p: None,
-                typical_p: None,
-                do_sample: false,
-                max_new_tokens: 1,
-                return_full_text: None,
-                stop: Vec::new(),
-                truncate: None,
-                watermark: false,
-                details: false,
-                seed: None,
-            },
-        })
-        .await?;
-    Ok(())
+async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
+    match health.check().await {
+        true => Ok(()),
+        false => Err((
+            StatusCode::SERVICE_UNAVAILABLE,
+            Json(ErrorResponse {
+                error: "unhealthy".to_string(),
+                error_type: "healthcheck".to_string(),
+            }),
+        )),
+    }
 }

 /// Generate tokens
@@ -555,6 +551,8 @@ pub async fn run(
         max_input_length,
         max_total_tokens,
     );
+    let generation_health = Arc::new(AtomicBool::new(false));
+    let health_ext = Health::new(client.clone(), generation_health.clone());
     let infer = Infer::new(
         client,
         validation,
@@ -563,6 +561,7 @@ pub async fn run(
         max_waiting_tokens,
         max_concurrent_requests,
         shard_info.requires_padding,
+        generation_health,
     );

     // Duration buckets
@@ -657,6 +656,7 @@ pub async fn run(
         // Prometheus metrics route
         .route("/metrics", get(metrics))
         .layer(Extension(info))
+        .layer(Extension(health_ext))
         .layer(Extension(compat_return_full_text))
         .layer(Extension(infer))
         .layer(Extension(prom_handle))
@@ -741,4 +741,3 @@ impl From<InferError> for Event {
             .unwrap()
     }
 }
-
diff --git a/router/src/validation.rs b/router/src/validation.rs
index ff2fe89d..cbb0d9cd 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -380,111 +380,154 @@ pub enum ValidationError {
 }

 #[cfg(test)]
-mod tests{
+mod tests {
     use super::*;
     use crate::default_parameters;
     use crate::tests::get_tokenizer;

     #[tokio::test]
-    async fn test_validation_max_new_tokens(){
+    async fn test_validation_max_new_tokens() {
         let tokenizer = None;
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );

         let max_new_tokens = 10;
-        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+        match validation
+            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .await
+        {
             Err(ValidationError::MaxNewTokens(1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens")
+            _ => panic!("Unexpected not max new tokens"),
         }
     }

     #[tokio::test]
-    async fn test_validation_input_length(){
+    async fn test_validation_input_length() {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );

         let max_new_tokens = 10;
-        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+        match validation
+            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .await
+        {
             Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens")
+            _ => panic!("Unexpected not max new tokens"),
         }
     }

     #[tokio::test]
-    async fn test_validation_best_of_sampling(){
+    async fn test_validation_best_of_sampling() {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
-        match validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters{
-                best_of: Some(2),
-                do_sample: false,
-                ..default_parameters()
-            }
-        }).await{
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    best_of: Some(2),
+                    do_sample: false,
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
             Err(ValidationError::BestOfSampling) => (),
-            _ => panic!("Unexpected not best of sampling")
+            _ => panic!("Unexpected not best of sampling"),
         }
-
     }

     #[tokio::test]
-    async fn test_validation_top_p(){
+    async fn test_validation_top_p() {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
-        match validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters{
-                top_p: Some(1.0),
-                ..default_parameters()
-            }
-        }).await{
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_p: Some(1.0),
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
             Err(ValidationError::TopP) => (),
-            _ => panic!("Unexpected top_p")
+            _ => panic!("Unexpected top_p"),
         }

-        match validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters{
-                top_p: Some(0.99),
-                max_new_tokens: 1,
-                ..default_parameters()
-            }
-        }).await{
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_p: Some(0.99),
+                    max_new_tokens: 1,
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
             Ok(_) => (),
-            _ => panic!("Unexpected top_p error")
+            _ => panic!("Unexpected top_p error"),
         }

-        let valid_request = validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters{
-                top_p: None,
-                max_new_tokens: 1,
-                ..default_parameters()
-            }
-        }).await.unwrap();
+        let valid_request = validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_p: None,
+                    max_new_tokens: 1,
+                    ..default_parameters()
+                },
+            })
+            .await
+            .unwrap();
         // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
         assert_eq!(valid_request.parameters.top_p, 1.0);
-
-
     }
 }
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index ddb7aae9..70f08ed7 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -29,6 +29,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     async def Info(self, request, context):
         return self.model.info

+    async def Health(self, request, context):
+        if self.model.device.type == "cuda":
+            torch.zeros((2, 2)).cuda()
+        return generate_pb2.HealthResponse()
+
    async def ServiceDiscovery(self, request, context):
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
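
---
Note (not part of the patch): a minimal sketch of how the reworked /health route could be exercised once the router is running. It assumes the router is reachable at localhost:3000 (the usual default port; adjust host and port to your deployment). Per the handler added above, the route returns 200 with an empty body when the check passes, and 503 with {"error": "unhealthy", "error_type": "healthcheck"} otherwise.

# probe_health.py -- hypothetical helper for manual testing, not part of this diff.
# Assumes the router listens on localhost:3000; change URL for your setup.
import sys
import urllib.error
import urllib.request

URL = "http://localhost:3000/health"

try:
    with urllib.request.urlopen(URL, timeout=2) as resp:
        # 200 means the shards answered the gRPC health check
        # (or the dummy prefill used as a fallback succeeded).
        print(f"healthy: HTTP {resp.status}")
        sys.exit(0)
except urllib.error.HTTPError as err:
    # 503 carries the JSON error payload emitted by the new handler.
    print(f"unhealthy: HTTP {err.code} {err.read().decode()}")
    sys.exit(1)
except OSError as err:
    # Connection refused, DNS failure, timeout, etc.
    print(f"unreachable: {err}")
    sys.exit(1)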