feat(router): new healthcheck that skips the queue (#244)
Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
This commit is contained in:
parent c4fb09f2ae
commit db2b4e0754
@@ -67,7 +67,10 @@ jobs:
        run: |
          pip install pytest
          HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests
-      - name: Run Clippy
+      - name: Run Rust fmt
+        run: |
+          cargo fmt --check
+      - name: Run Rust clippy
        run: |
          cargo clippy
      - name: Run Rust tests
@@ -493,6 +493,7 @@ fn download_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherE
     Ok(())
 }
 
+#[allow(clippy::too_many_arguments)]
 fn spawn_shards(
     num_shard: usize,
     args: &Args,
@@ -515,11 +516,11 @@ fn spawn_shards(
         let shutdown = shutdown.clone();
         let shutdown_sender = shutdown_sender.clone();
         let otlp_endpoint = args.otlp_endpoint.clone();
-        let quantize = args.quantize.clone();
-        let master_port = args.master_port.clone();
-        let disable_custom_kernels = args.disable_custom_kernels.clone();
-        let watermark_gamma = args.watermark_gamma.clone();
-        let watermark_delta = args.watermark_delta.clone();
+        let quantize = args.quantize;
+        let master_port = args.master_port;
+        let disable_custom_kernels = args.disable_custom_kernels;
+        let watermark_gamma = args.watermark_gamma;
+        let watermark_delta = args.watermark_delta;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -559,12 +560,12 @@ fn spawn_shards(
             }
             Ok(ShardStatus::Failed((rank, err))) => {
                 tracing::error!("Shard {} failed to start:\n{}", rank, err);
-                shutdown_shards(shutdown, &shutdown_receiver);
+                shutdown_shards(shutdown, shutdown_receiver);
                 return Err(LauncherError::ShardCannotStart);
             }
             Err(TryRecvError::Disconnected) => {
                 tracing::error!("Shard status channel disconnected");
-                shutdown_shards(shutdown, &shutdown_receiver);
+                shutdown_shards(shutdown, shutdown_receiver);
                 return Err(LauncherError::ShardDisconnected);
             }
         }
@@ -666,7 +667,7 @@ fn spawn_webserver(
             tracing::error!("{}", err);
         }
 
-        shutdown_shards(shutdown, &shutdown_receiver);
+        shutdown_shards(shutdown, shutdown_receiver);
         return Err(LauncherError::WebserverCannotStart);
     }
 };
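Note: the dropped `.clone()` calls above are the usual fix for clippy's `clone_on_copy` lint: these launcher arguments are presumably `Copy` types, so a plain copy is enough. A minimal sketch with a hypothetical struct (not the launcher's real `Args`):

#[derive(Clone, Copy)]
struct Settings {
    master_port: u16,
    disable_custom_kernels: bool,
}

fn main() {
    let args = Settings { master_port: 29500, disable_custom_kernels: false };
    // Implicit copies; `.clone()` here would be redundant and flagged by clippy.
    let master_port = args.master_port;
    let disable_custom_kernels = args.disable_custom_kernels;
    println!("{master_port} {disable_custom_kernels}");
}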
@@ -15,8 +15,13 @@ service TextGenerationService {
     rpc Prefill (PrefillRequest) returns (PrefillResponse);
     /// Decode token for a list of prefilled batches
     rpc Decode (DecodeRequest) returns (DecodeResponse);
+    /// Health check
+    rpc Health (HealthRequest) returns (HealthResponse);
 }
 
+message HealthRequest {}
+message HealthResponse {}
+
 /// Empty request
 message InfoRequest {}
 
@@ -173,4 +178,4 @@ message DecodeResponse {
     repeated Generation generations = 1;
     /// Next batch (cached)
     optional Batch batch = 2;
 }
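Note: both new messages are intentionally empty; the call only asks whether the shard answers. On the Rust side, prost/tonic codegen turns them into plain structs under `pb::generate::v1` (re-exported later in this diff). The sketch below is illustrative only; the real generated code carries extra derives plus the client/server stubs.

pub mod generate {
    pub mod v1 {
        // Illustrative only: empty protobuf messages become field-less structs.
        #[derive(Clone, Debug, Default, PartialEq)]
        pub struct HealthRequest {}

        #[derive(Clone, Debug, Default, PartialEq)]
        pub struct HealthResponse {}
    }
}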
@@ -7,7 +7,7 @@ use tonic::transport::{Channel, Uri};
 use tracing::instrument;
 
 /// Text Generation Inference gRPC client
-#[derive(Clone)]
+#[derive(Debug, Clone)]
 pub struct Client {
     stub: TextGenerationServiceClient<Channel>,
 }
@@ -62,6 +62,14 @@ impl Client {
         Ok(response)
     }
 
+    /// Get model health
+    #[instrument(skip(self))]
+    pub async fn health(&mut self) -> Result<HealthResponse> {
+        let request = tonic::Request::new(HealthRequest {}).inject_context();
+        let response = self.stub.health(request).await?.into_inner();
+        Ok(response)
+    }
+
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
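Note: with the new RPC wired into `Client`, a shard liveness probe reduces to checking whether `health()` resolves to `Ok`. A small library-style sketch built only on the types added in this diff (how the `Client` is constructed is left out here):

use text_generation_client::Client;

/// Returns true when the shard's gRPC server answers the Health RPC.
pub async fn shard_is_alive(client: &mut Client) -> bool {
    client.health().await.is_ok()
}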
@@ -6,6 +6,7 @@ mod pb;
 mod sharded_client;
 
 pub use client::Client;
+pub use pb::generate::v1::HealthResponse;
 pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
     Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
@@ -1,10 +1,11 @@
 /// Multi shard Client
 use crate::Result;
-use crate::{Batch, Client, Generation, Request, ShardInfo};
+use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;
 
+#[derive(Debug, Clone)]
 /// Text Generation Inference gRPC multi client
 pub struct ShardedClient {
     clients: Vec<Client>,
@@ -48,6 +49,17 @@ impl ShardedClient {
         join_all(futures).await.pop().unwrap()
     }
 
+    /// GRPC health check
+    #[instrument(skip(self))]
+    pub async fn health(&mut self) -> Result<HealthResponse> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| client.health())
+            .collect();
+        join_all(futures).await.pop().unwrap()
+    }
+
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
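Note: `ShardedClient::health` reuses the crate's existing fan-out idiom: fire the call at every shard, wait for all of them with `join_all`, and keep only the last result via `.pop().unwrap()`. A toy sketch of that idiom, assuming `futures` and `tokio` as dependencies:

use futures::future::join_all;

#[tokio::main]
async fn main() {
    // Stand-ins for the per-shard gRPC health calls.
    let shard_calls: Vec<_> = (1u32..=3).map(|i| async move { Ok::<u32, ()>(i) }).collect();
    // Every call is driven to completion; only the last shard's result is kept.
    let last = join_all(shard_calls).await.pop().unwrap();
    assert_eq!(last, Ok(3));
}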
@@ -0,0 +1,62 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+use text_generation_client::{
+    Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
+};
+
+#[derive(Clone, Debug)]
+pub(crate) struct Health {
+    client: ShardedClient,
+    generation_health: Arc<AtomicBool>,
+}
+
+impl Health {
+    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
+        Self {
+            client,
+            generation_health,
+        }
+    }
+
+    pub(crate) async fn check(&mut self) -> bool {
+        if self.generation_health.load(Ordering::SeqCst) {
+            // Generation is healthy, we only check that the shards are answering gRPC calls
+            self.client.health().await.is_ok()
+        } else {
+            // Generation is unhealthy or have not sent any generation request yet
+
+            // Dummy batch of 1 token and 1 generated token
+            let liveness_request = Request {
+                id: u64::MAX,
+                inputs: "liveness".to_string(),
+                truncate: 10,
+                parameters: Some(NextTokenChooserParameters {
+                    temperature: 1.0,
+                    top_k: 0,
+                    top_p: 1.0,
+                    typical_p: 1.0,
+                    do_sample: false,
+                    seed: 0,
+                    repetition_penalty: 1.0,
+                    watermark: false,
+                }),
+                stopping_parameters: Some(StoppingCriteriaParameters {
+                    max_new_tokens: 1,
+                    stop_sequences: vec![],
+                    ignore_eos_token: false,
+                }),
+            };
+            let batch = Batch {
+                id: u64::MAX,
+                requests: vec![liveness_request],
+                size: 1,
+                max_tokens: 2,
+            };
+            // Skips the queue
+            let value = self.client.prefill(batch).await.is_ok();
+            // Update generation health
+            self.generation_health.store(value, Ordering::SeqCst);
+            value
+        }
+    }
+}
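Note: `check()` is two-tiered. While the shared `generation_health` flag is true (set by the batching task after a successful prefill/decode), the router only pings the shards over gRPC; when the flag is false or never set, it sends the 1-token dummy prefill above straight to the shards, which is what lets the health check skip the request queue. A minimal sketch of that decision, using the same primitives:

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

fn main() {
    // Starts false: nothing has been generated yet.
    let generation_health = Arc::new(AtomicBool::new(false));

    if generation_health.load(Ordering::SeqCst) {
        println!("fast path: only verify the shards answer the Health RPC");
    } else {
        println!("slow path: run a 1-token dummy prefill, bypassing the queue");
    }
}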
@@ -7,7 +7,10 @@ use flume::SendError;
 use futures::future::try_join_all;
 use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
-use std::sync::Arc;
+use std::sync::{
+    atomic::{AtomicBool, Ordering},
+    Arc,
+};
 use text_generation_client::{
     Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
@@ -36,6 +39,7 @@ struct Shared {
 }
 
 impl Infer {
+    #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         client: ShardedClient,
         validation: Validation,
@@ -44,6 +48,7 @@ impl Infer {
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
         requires_padding: bool,
+        generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(requires_padding);
@@ -59,6 +64,7 @@ impl Infer {
             max_waiting_tokens,
             queue.clone(),
             shared.clone(),
+            generation_health,
         ));
 
         // Inference limit with a semaphore
@@ -240,6 +246,7 @@ async fn batching_task(
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
+    generation_health: Arc<AtomicBool>,
 ) {
     // Infinite loop
     loop {
@@ -252,7 +259,7 @@ async fn batching_task(
         while let Some((mut entries, batch, span)) =
             queue.next_batch(None, max_batch_total_tokens).await
         {
-            let mut cached_batch = prefill(&mut client, batch, &mut entries)
+            let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
                 .instrument(span)
                 .await;
             let mut waiting_tokens = 1;
@@ -301,9 +308,10 @@ async fn batching_task(
                     });
 
                     // Generate one token for this new batch to have the attention past in cache
-                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
-                        .instrument(span)
-                        .await;
+                    let new_cached_batch =
+                        prefill(&mut client, new_batch, &mut new_entries, &generation_health)
+                            .instrument(span)
+                            .await;
                     // Reset waiting counter
                     waiting_tokens = 1;
                     // Extend current batch with the new batch
@@ -327,7 +335,7 @@ async fn batching_task(
                     entry.temp_span = Some(entry_batch_span);
                 });
 
-                cached_batch = decode(&mut client, batches, &mut entries)
+                cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
                     .instrument(next_batch_span)
                     .await;
                 waiting_tokens += 1;
@@ -343,6 +351,7 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
@@ -350,6 +359,8 @@ async fn prefill(
 
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
+            // Update health
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
@@ -362,6 +373,8 @@ async fn prefill(
         }
         // If we have an error, we discard the whole batch
        Err(err) => {
+            // Update health
+            generation_health.store(false, Ordering::SeqCst);
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
@@ -375,6 +388,7 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<Batch>,
     entries: &mut IntMap<u64, Entry>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
@@ -382,6 +396,8 @@ async fn decode(
 
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
+            // Update health
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
@@ -394,6 +410,7 @@ async fn decode(
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
+            generation_health.store(false, Ordering::SeqCst);
             for id in batch_ids {
                 let _ = client.clear_cache(Some(id)).await;
             }
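Note: the batching task is the writer side of that shared flag: every prefill/decode outcome is recorded into the `AtomicBool`, so the health endpoint can usually avoid issuing its own generation. A library-style sketch of the pattern, with `do_generation` as a hypothetical stand-in for the real gRPC prefill/decode call:

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Hypothetical stand-in for `client.prefill(..)` / `client.decode(..)`.
async fn do_generation() -> Result<(), &'static str> {
    Ok(())
}

pub async fn generate_and_record(generation_health: &Arc<AtomicBool>) -> Result<(), &'static str> {
    let result = do_generation().await;
    // Mirror of the `store(true/false, Ordering::SeqCst)` calls added above.
    generation_health.store(result.is_ok(), Ordering::SeqCst);
    result
}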
@@ -1,3 +1,4 @@
+mod health;
 /// Text Generation Inference Webserver
 mod infer;
 mod queue;
@@ -278,17 +279,21 @@ pub(crate) struct ErrorResponse {
 }
 
 #[cfg(test)]
-mod tests{
+mod tests {
     use std::io::Write;
     use tokenizers::Tokenizer;
 
-    pub(crate) async fn get_tokenizer() -> Tokenizer{
-        if !std::path::Path::new("tokenizer.json").exists(){
-            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap();
-            let mut file = std::fs::File::create("tokenizer.json").unwrap();
+    pub(crate) async fn get_tokenizer() -> Tokenizer {
+        if !std::path::Path::new("tokenizer.json").exists() {
+            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
+                .await
+                .unwrap()
+                .bytes()
+                .await
+                .unwrap();
+            let mut file = std::fs::File::create("tokenizer.json").unwrap();
             file.write_all(&content).unwrap();
         }
         Tokenizer::from_file("tokenizer.json").unwrap()
     }
 }
@@ -141,7 +141,6 @@ impl State {
 
     // Get the next batch
     fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
-
         if self.entries.is_empty() {
             return None;
         }
@@ -1,3 +1,4 @@
+use crate::health::Health;
 /// HTTP Server logic
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
@@ -18,6 +19,8 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
@@ -82,36 +85,29 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
     Json(info.0)
 }
 
+#[utoipa::path(
+get,
+tag = "Text Generation Inference",
+path = "/health",
+responses(
+(status = 200, description = "Everything is working fine"),
+(status = 503, description = "Text generation inference is down", body = ErrorResponse,
+example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
+)
+)]
+#[instrument(skip(health))]
 /// Health check method
-#[instrument(skip(infer))]
-async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
-    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
-    // be a bit too slow for a health check.
-    // What we should do instead is check if the gRPC channels are still healthy.
-
-    // Send a small inference request
-    infer
-        .generate(GenerateRequest {
-            inputs: "liveness".to_string(),
-            parameters: GenerateParameters {
-                best_of: None,
-                temperature: None,
-                repetition_penalty: None,
-                top_k: None,
-                top_p: None,
-                typical_p: None,
-                do_sample: false,
-                max_new_tokens: 1,
-                return_full_text: None,
-                stop: Vec::new(),
-                truncate: None,
-                watermark: false,
-                details: false,
-                seed: None,
-            },
-        })
-        .await?;
-    Ok(())
+async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
+    match health.check().await {
+        true => Ok(()),
+        false => Err((
+            StatusCode::SERVICE_UNAVAILABLE,
+            Json(ErrorResponse {
+                error: "unhealthy".to_string(),
+                error_type: "healthcheck".to_string(),
+            }),
+        )),
+    }
 }
 
 /// Generate tokens
@@ -555,6 +551,8 @@ pub async fn run(
         max_input_length,
         max_total_tokens,
     );
+    let generation_health = Arc::new(AtomicBool::new(false));
+    let health_ext = Health::new(client.clone(), generation_health.clone());
     let infer = Infer::new(
         client,
         validation,
@@ -563,6 +561,7 @@ pub async fn run(
         max_waiting_tokens,
         max_concurrent_requests,
         shard_info.requires_padding,
+        generation_health,
     );
 
     // Duration buckets
@@ -657,6 +656,7 @@ pub async fn run(
         // Prometheus metrics route
         .route("/metrics", get(metrics))
         .layer(Extension(info))
+        .layer(Extension(health_ext))
         .layer(Extension(compat_return_full_text))
         .layer(Extension(infer))
         .layer(Extension(prom_handle))
@@ -741,4 +741,3 @@ impl From<InferError> for Event {
             .unwrap()
     }
 }
-
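Note: the `/health` route now answers 200 when `Health::check` passes and 503 with a small JSON error body otherwise, which makes it usable as a liveness or readiness probe. A hedged client-side sketch using `reqwest`; the address and port are assumptions of the example, not something fixed by this diff:

use reqwest::StatusCode;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let status = reqwest::get("http://localhost:3000/health").await?.status();
    if status == StatusCode::OK {
        println!("router healthy");
    } else if status == StatusCode::SERVICE_UNAVAILABLE {
        println!("router unhealthy");
    } else {
        println!("unexpected status: {status}");
    }
    Ok(())
}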
@@ -380,111 +380,154 @@ pub enum ValidationError {
 }
 
 #[cfg(test)]
-mod tests{
+mod tests {
     use super::*;
     use crate::default_parameters;
     use crate::tests::get_tokenizer;
 
     #[tokio::test]
-    async fn test_validation_max_new_tokens(){
+    async fn test_validation_max_new_tokens() {
         let tokenizer = None;
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
 
         let max_new_tokens = 10;
-        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+        match validation
+            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .await
+        {
             Err(ValidationError::MaxNewTokens(1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens")
+            _ => panic!("Unexpected not max new tokens"),
         }
     }
 
     #[tokio::test]
-    async fn test_validation_input_length(){
+    async fn test_validation_input_length() {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
 
         let max_new_tokens = 10;
-        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+        match validation
+            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .await
+        {
             Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens")
+            _ => panic!("Unexpected not max new tokens"),
         }
     }
 
     #[tokio::test]
-    async fn test_validation_best_of_sampling(){
+    async fn test_validation_best_of_sampling() {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
-        match validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters{
-                best_of: Some(2),
-                do_sample: false,
-                ..default_parameters()
-            }
-        }).await{
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    best_of: Some(2),
+                    do_sample: false,
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
             Err(ValidationError::BestOfSampling) => (),
-            _ => panic!("Unexpected not best of sampling")
+            _ => panic!("Unexpected not best of sampling"),
         }
 
     }
 
     #[tokio::test]
-    async fn test_validation_top_p(){
+    async fn test_validation_top_p() {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
-        match validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters{
-                top_p: Some(1.0),
-                ..default_parameters()
-            }
-        }).await{
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_p: Some(1.0),
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
             Err(ValidationError::TopP) => (),
-            _ => panic!("Unexpected top_p")
+            _ => panic!("Unexpected top_p"),
         }
 
-        match validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters{
-                top_p: Some(0.99),
-                max_new_tokens: 1,
-                ..default_parameters()
-            }
-        }).await{
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_p: Some(0.99),
+                    max_new_tokens: 1,
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
             Ok(_) => (),
-            _ => panic!("Unexpected top_p error")
+            _ => panic!("Unexpected top_p error"),
        }
 
-        let valid_request = validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters{
-                top_p: None,
-                max_new_tokens: 1,
-                ..default_parameters()
-            }
-        }).await.unwrap();
+        let valid_request = validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_p: None,
+                    max_new_tokens: 1,
+                    ..default_parameters()
+                },
+            })
+            .await
+            .unwrap();
         // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
         assert_eq!(valid_request.parameters.top_p, 1.0);
 
 
     }
 }
@@ -29,6 +29,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     async def Info(self, request, context):
         return self.model.info
 
+    async def Health(self, request, context):
+        if self.model.device.type == "cuda":
+            torch.zeros((2, 2)).cuda()
+        return generate_pb2.HealthResponse()
+
     async def ServiceDiscovery(self, request, context):
         return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
 