feat(router): new healthcheck that skips the queue (#244)

Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
Nicolas Patry 2023-04-26 20:23:54 +02:00 committed by GitHub
parent c4fb09f2ae
commit db2b4e0754
13 changed files with 265 additions and 105 deletions


@@ -67,7 +67,10 @@ jobs:
         run: |
           pip install pytest
           HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests
-      - name: Run Clippy
+      - name: Run Rust fmt
+        run: |
+          cargo fmt --check
+      - name: Run Rust clippy
         run: |
           cargo clippy
       - name: Run Rust tests


@@ -493,6 +493,7 @@ fn download_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherE
     Ok(())
 }
 
+#[allow(clippy::too_many_arguments)]
 fn spawn_shards(
     num_shard: usize,
     args: &Args,
@@ -515,11 +516,11 @@ fn spawn_shards(
         let shutdown = shutdown.clone();
         let shutdown_sender = shutdown_sender.clone();
         let otlp_endpoint = args.otlp_endpoint.clone();
-        let quantize = args.quantize.clone();
-        let master_port = args.master_port.clone();
-        let disable_custom_kernels = args.disable_custom_kernels.clone();
-        let watermark_gamma = args.watermark_gamma.clone();
-        let watermark_delta = args.watermark_delta.clone();
+        let quantize = args.quantize;
+        let master_port = args.master_port;
+        let disable_custom_kernels = args.disable_custom_kernels;
+        let watermark_gamma = args.watermark_gamma;
+        let watermark_delta = args.watermark_delta;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -559,12 +560,12 @@ fn spawn_shards(
             }
             Ok(ShardStatus::Failed((rank, err))) => {
                 tracing::error!("Shard {} failed to start:\n{}", rank, err);
-                shutdown_shards(shutdown, &shutdown_receiver);
+                shutdown_shards(shutdown, shutdown_receiver);
                 return Err(LauncherError::ShardCannotStart);
             }
             Err(TryRecvError::Disconnected) => {
                 tracing::error!("Shard status channel disconnected");
-                shutdown_shards(shutdown, &shutdown_receiver);
+                shutdown_shards(shutdown, shutdown_receiver);
                 return Err(LauncherError::ShardDisconnected);
             }
         }
@@ -666,7 +667,7 @@ fn spawn_webserver(
                 tracing::error!("{}", err);
             }
 
-            shutdown_shards(shutdown, &shutdown_receiver);
+            shutdown_shards(shutdown, shutdown_receiver);
             return Err(LauncherError::WebserverCannotStart);
         }
     };


@@ -15,8 +15,13 @@ service TextGenerationService {
     rpc Prefill (PrefillRequest) returns (PrefillResponse);
     /// Decode token for a list of prefilled batches
     rpc Decode (DecodeRequest) returns (DecodeResponse);
+    /// Health check
+    rpc Health (HealthRequest) returns (HealthResponse);
 }
 
+message HealthRequest {}
+message HealthResponse {}
+
 /// Empty request
 message InfoRequest {}
 
@@ -173,4 +178,4 @@ message DecodeResponse {
     repeated Generation generations = 1;
     /// Next batch (cached)
     optional Batch batch = 2;
-}
\ No newline at end of file
+}


@@ -7,7 +7,7 @@ use tonic::transport::{Channel, Uri};
 use tracing::instrument;
 
 /// Text Generation Inference gRPC client
-#[derive(Clone)]
+#[derive(Debug, Clone)]
 pub struct Client {
     stub: TextGenerationServiceClient<Channel>,
 }
@@ -62,6 +62,14 @@ impl Client {
         Ok(response)
     }
 
+    /// Get model health
+    #[instrument(skip(self))]
+    pub async fn health(&mut self) -> Result<HealthResponse> {
+        let request = tonic::Request::new(HealthRequest {}).inject_context();
+        let response = self.stub.health(request).await?.into_inner();
+        Ok(response)
+    }
+
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {


@@ -6,6 +6,7 @@ mod pb;
 mod sharded_client;
 
 pub use client::Client;
+pub use pb::generate::v1::HealthResponse;
 pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
     Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,


@@ -1,10 +1,11 @@
 /// Multi shard Client
 use crate::Result;
-use crate::{Batch, Client, Generation, Request, ShardInfo};
+use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;
 
+#[derive(Debug, Clone)]
 /// Text Generation Inference gRPC multi client
 pub struct ShardedClient {
     clients: Vec<Client>,
@@ -48,6 +49,17 @@ impl ShardedClient {
         join_all(futures).await.pop().unwrap()
     }
 
+    /// GRPC health check
+    #[instrument(skip(self))]
+    pub async fn health(&mut self) -> Result<HealthResponse> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| client.health())
+            .collect();
+        join_all(futures).await.pop().unwrap()
+    }
+
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {

router/src/health.rs (new file, 62 lines)

@@ -0,0 +1,62 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+use text_generation_client::{
+    Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
+};
+
+#[derive(Clone, Debug)]
+pub(crate) struct Health {
+    client: ShardedClient,
+    generation_health: Arc<AtomicBool>,
+}
+
+impl Health {
+    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
+        Self {
+            client,
+            generation_health,
+        }
+    }
+
+    pub(crate) async fn check(&mut self) -> bool {
+        if self.generation_health.load(Ordering::SeqCst) {
+            // Generation is healthy, we only check that the shards are answering gRPC calls
+            self.client.health().await.is_ok()
+        } else {
+            // Generation is unhealthy or we have not sent any generation request yet
+            // Dummy batch of 1 token and 1 generated token
+            let liveness_request = Request {
+                id: u64::MAX,
+                inputs: "liveness".to_string(),
+                truncate: 10,
+                parameters: Some(NextTokenChooserParameters {
+                    temperature: 1.0,
+                    top_k: 0,
+                    top_p: 1.0,
+                    typical_p: 1.0,
+                    do_sample: false,
+                    seed: 0,
+                    repetition_penalty: 1.0,
+                    watermark: false,
+                }),
+                stopping_parameters: Some(StoppingCriteriaParameters {
+                    max_new_tokens: 1,
+                    stop_sequences: vec![],
+                    ignore_eos_token: false,
+                }),
+            };
+            let batch = Batch {
+                id: u64::MAX,
+                requests: vec![liveness_request],
+                size: 1,
+                max_tokens: 2,
+            };
+            // Skips the queue
+            let value = self.client.prefill(batch).await.is_ok();
+            // Update generation health
+            self.generation_health.store(value, Ordering::SeqCst);
+            value
+        }
+    }
+}
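
For orientation, a minimal caller inside the router crate might look like the sketch below. This is illustrative only: the socket path, the error type, and the use of ShardedClient::connect_uds are assumptions for the example, not part of this commit.

    use std::sync::atomic::AtomicBool;
    use std::sync::Arc;
    use text_generation_client::ShardedClient;

    // Hypothetical helper mirroring the wiring done in `run` later in this commit.
    async fn health_example(uds_path: String) -> Result<(), Box<dyn std::error::Error>> {
        // Connect to the master shard socket (path is supplied by the caller).
        let sharded_client = ShardedClient::connect_uds(uds_path).await?;
        // Starts false: no successful generation has been observed yet.
        let generation_health = Arc::new(AtomicBool::new(false));
        let mut health = Health::new(sharded_client, generation_health.clone());

        // While the flag is false, check() sends the dummy prefill defined above;
        // once real prefills/decodes succeed, it only pings the shards' Health RPC.
        let healthy = health.check().await;
        println!("healthy = {healthy}");
        Ok(())
    }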


@@ -7,7 +7,10 @@ use flume::SendError;
 use futures::future::try_join_all;
 use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
-use std::sync::Arc;
+use std::sync::{
+    atomic::{AtomicBool, Ordering},
+    Arc,
+};
 use text_generation_client::{
     Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
@@ -36,6 +39,7 @@ struct Shared {
 }
 
 impl Infer {
+    #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         client: ShardedClient,
         validation: Validation,
@@ -44,6 +48,7 @@ impl Infer {
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
         requires_padding: bool,
+        generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(requires_padding);
@@ -59,6 +64,7 @@ impl Infer {
             max_waiting_tokens,
             queue.clone(),
             shared.clone(),
+            generation_health,
         ));
 
         // Inference limit with a semaphore
@@ -240,6 +246,7 @@ async fn batching_task(
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
+    generation_health: Arc<AtomicBool>,
 ) {
     // Infinite loop
     loop {
@@ -252,7 +259,7 @@ async fn batching_task(
         while let Some((mut entries, batch, span)) =
             queue.next_batch(None, max_batch_total_tokens).await
         {
-            let mut cached_batch = prefill(&mut client, batch, &mut entries)
+            let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
                 .instrument(span)
                 .await;
             let mut waiting_tokens = 1;
@@ -301,9 +308,10 @@ async fn batching_task(
                 });
 
                 // Generate one token for this new batch to have the attention past in cache
-                let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
-                    .instrument(span)
-                    .await;
+                let new_cached_batch =
+                    prefill(&mut client, new_batch, &mut new_entries, &generation_health)
+                        .instrument(span)
+                        .await;
                 // Reset waiting counter
                 waiting_tokens = 1;
                 // Extend current batch with the new batch
@@ -327,7 +335,7 @@ async fn batching_task(
                 entry.temp_span = Some(entry_batch_span);
             });
 
-            cached_batch = decode(&mut client, batches, &mut entries)
+            cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
                 .instrument(next_batch_span)
                 .await;
             waiting_tokens += 1;
@@ -343,6 +351,7 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
@@ -350,6 +359,8 @@ async fn prefill(
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
+            // Update health
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
@@ -362,6 +373,8 @@ async fn prefill(
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
+            // Update health
+            generation_health.store(false, Ordering::SeqCst);
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
@@ -375,6 +388,7 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<Batch>,
     entries: &mut IntMap<u64, Entry>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
@@ -382,6 +396,8 @@ async fn decode(
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
+            // Update health
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
@@ -394,6 +410,7 @@ async fn decode(
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
+            generation_health.store(false, Ordering::SeqCst);
             for id in batch_ids {
                 let _ = client.clear_cache(Some(id)).await;
             }
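
The generation_health flag added above is the only coupling between the batching task and the health checker: prefill and decode store true on success and false on failure, and Health::check reads the flag to pick between the cheap gRPC ping and the dummy prefill. A stripped-down sketch of that handshake (illustrative only; the task bodies are stand-ins, not code from this commit):

    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    #[tokio::main]
    async fn main() {
        // Shared flag: false means "no successful generation observed yet".
        let generation_health = Arc::new(AtomicBool::new(false));

        // Writer side, standing in for the batching task in infer.rs.
        let writer_flag = generation_health.clone();
        tokio::spawn(async move {
            let prefill_succeeded = true; // would be `client.prefill(batch).await.is_ok()`
            writer_flag.store(prefill_succeeded, Ordering::SeqCst);
        });

        // Reader side, standing in for Health::check.
        if generation_health.load(Ordering::SeqCst) {
            // Cheap path: only ping the shards over gRPC.
        } else {
            // Expensive path: send the dummy prefill that skips the queue.
        }
    }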


@@ -1,3 +1,4 @@
+mod health;
 /// Text Generation Inference Webserver
 mod infer;
 mod queue;
@@ -278,17 +279,21 @@ pub(crate) struct ErrorResponse {
 }
 
 #[cfg(test)]
-mod tests{
+mod tests {
     use std::io::Write;
     use tokenizers::Tokenizer;
 
-    pub(crate) async fn get_tokenizer() -> Tokenizer{
-        if !std::path::Path::new("tokenizer.json").exists(){
-            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap();
-            let mut file = std::fs::File::create("tokenizer.json").unwrap();
+    pub(crate) async fn get_tokenizer() -> Tokenizer {
+        if !std::path::Path::new("tokenizer.json").exists() {
+            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
+                .await
+                .unwrap()
+                .bytes()
+                .await
+                .unwrap();
+            let mut file = std::fs::File::create("tokenizer.json").unwrap();
             file.write_all(&content).unwrap();
         }
         Tokenizer::from_file("tokenizer.json").unwrap()
     }
 }


@@ -141,7 +141,6 @@ impl State {
     // Get the next batch
     fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
         if self.entries.is_empty() {
-
             return None;
         }


@@ -1,3 +1,4 @@
+use crate::health::Health;
 /// HTTP Server logic
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
@@ -18,6 +19,8 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
@@ -82,36 +85,29 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
     Json(info.0)
 }
 
+#[utoipa::path(
+get,
+tag = "Text Generation Inference",
+path = "/health",
+responses(
+(status = 200, description = "Everything is working fine"),
+(status = 503, description = "Text generation inference is down", body = ErrorResponse,
+example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
+)
+)]
+#[instrument(skip(health))]
 /// Health check method
-#[instrument(skip(infer))]
-async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
-    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
-    // be a bit too slow for a health check.
-    // What we should do instead is check if the gRPC channels are still healthy.
-
-    // Send a small inference request
-    infer
-        .generate(GenerateRequest {
-            inputs: "liveness".to_string(),
-            parameters: GenerateParameters {
-                best_of: None,
-                temperature: None,
-                repetition_penalty: None,
-                top_k: None,
-                top_p: None,
-                typical_p: None,
-                do_sample: false,
-                max_new_tokens: 1,
-                return_full_text: None,
-                stop: Vec::new(),
-                truncate: None,
-                watermark: false,
-                details: false,
-                seed: None,
-            },
-        })
-        .await?;
-    Ok(())
+async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
+    match health.check().await {
+        true => Ok(()),
+        false => Err((
+            StatusCode::SERVICE_UNAVAILABLE,
+            Json(ErrorResponse {
+                error: "unhealthy".to_string(),
+                error_type: "healthcheck".to_string(),
+            }),
+        )),
+    }
 }
 
 /// Generate tokens
@@ -555,6 +551,8 @@ pub async fn run(
         max_input_length,
         max_total_tokens,
     );
+    let generation_health = Arc::new(AtomicBool::new(false));
+    let health_ext = Health::new(client.clone(), generation_health.clone());
     let infer = Infer::new(
         client,
         validation,
@@ -563,6 +561,7 @@ pub async fn run(
         max_waiting_tokens,
         max_concurrent_requests,
         shard_info.requires_padding,
+        generation_health,
     );
 
     // Duration buckets
@@ -657,6 +656,7 @@ pub async fn run(
         // Prometheus metrics route
         .route("/metrics", get(metrics))
         .layer(Extension(info))
+        .layer(Extension(health_ext))
        .layer(Extension(compat_return_full_text))
         .layer(Extension(infer))
         .layer(Extension(prom_handle))
@@ -741,4 +741,3 @@ impl From<InferError> for Event {
             .unwrap()
     }
 }
-
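
For reference, the new route behaves like any other axum endpoint and can be probed directly. A minimal sketch (the base URL and the use of reqwest are assumptions for this example, not part of the commit) that treats 200 as healthy and 503 as unhealthy:

    use reqwest::StatusCode;

    #[tokio::main]
    async fn main() -> Result<(), reqwest::Error> {
        // The router address is an assumption for this example.
        let response = reqwest::get("http://127.0.0.1:3000/health").await?;
        match response.status() {
            StatusCode::OK => println!("healthy"),
            StatusCode::SERVICE_UNAVAILABLE => {
                // Body is the ErrorResponse JSON: {"error": "unhealthy", "error_type": "healthcheck"}
                println!("unhealthy: {}", response.text().await?);
            }
            other => println!("unexpected status: {other}"),
        }
        Ok(())
    }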


@@ -380,111 +380,154 @@ pub enum ValidationError {
 }
 
 #[cfg(test)]
-mod tests{
+mod tests {
     use super::*;
     use crate::default_parameters;
     use crate::tests::get_tokenizer;
 
     #[tokio::test]
-    async fn test_validation_max_new_tokens(){
+    async fn test_validation_max_new_tokens() {
         let tokenizer = None;
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
 
         let max_new_tokens = 10;
-        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+        match validation
+            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .await
+        {
             Err(ValidationError::MaxNewTokens(1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens")
+            _ => panic!("Unexpected not max new tokens"),
         }
     }
 
     #[tokio::test]
-    async fn test_validation_input_length(){
+    async fn test_validation_input_length() {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
 
         let max_new_tokens = 10;
-        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+        match validation
+            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .await
+        {
             Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens")
+            _ => panic!("Unexpected not max new tokens"),
         }
     }
 
     #[tokio::test]
-    async fn test_validation_best_of_sampling(){
+    async fn test_validation_best_of_sampling() {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
-        match validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters{
-                best_of: Some(2),
-                do_sample: false,
-                ..default_parameters()
-            }
-        }).await{
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    best_of: Some(2),
+                    do_sample: false,
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
             Err(ValidationError::BestOfSampling) => (),
-            _ => panic!("Unexpected not best of sampling")
+            _ => panic!("Unexpected not best of sampling"),
        }
     }
 
     #[tokio::test]
-    async fn test_validation_top_p(){
+    async fn test_validation_top_p() {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
-        match validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters{
-                top_p: Some(1.0),
-                ..default_parameters()
-            }
-        }).await{
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_p: Some(1.0),
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
             Err(ValidationError::TopP) => (),
-            _ => panic!("Unexpected top_p")
+            _ => panic!("Unexpected top_p"),
         }
 
-        match validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters{
-                top_p: Some(0.99),
-                max_new_tokens: 1,
-                ..default_parameters()
-            }
-        }).await{
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_p: Some(0.99),
+                    max_new_tokens: 1,
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
             Ok(_) => (),
-            _ => panic!("Unexpected top_p error")
+            _ => panic!("Unexpected top_p error"),
         }
 
-        let valid_request = validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters{
-                top_p: None,
-                max_new_tokens: 1,
-                ..default_parameters()
-            }
-        }).await.unwrap();
+        let valid_request = validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_p: None,
+                    max_new_tokens: 1,
+                    ..default_parameters()
+                },
+            })
+            .await
+            .unwrap();
         // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
         assert_eq!(valid_request.parameters.top_p, 1.0);
     }
 }


@@ -29,6 +29,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     async def Info(self, request, context):
         return self.model.info
 
+    async def Health(self, request, context):
+        if self.model.device.type == "cuda":
+            torch.zeros((2, 2)).cuda()
+        return generate_pb2.HealthResponse()
+
     async def ServiceDiscovery(self, request, context):
         return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)