feat(router): new healthcheck that skips the queue (#244)
Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
parent c4fb09f2ae
commit db2b4e0754
@@ -67,7 +67,10 @@ jobs:
         run: |
           pip install pytest
           HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests
-      - name: Run Clippy
+      - name: Run Rust fmt
+        run: |
+          cargo fmt --check
+      - name: Run Rust clippy
         run: |
           cargo clippy
       - name: Run Rust tests
@@ -493,6 +493,7 @@ fn download_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherE
     Ok(())
 }
 
+#[allow(clippy::too_many_arguments)]
 fn spawn_shards(
     num_shard: usize,
     args: &Args,
@@ -515,11 +516,11 @@ fn spawn_shards(
         let shutdown = shutdown.clone();
         let shutdown_sender = shutdown_sender.clone();
         let otlp_endpoint = args.otlp_endpoint.clone();
-        let quantize = args.quantize.clone();
-        let master_port = args.master_port.clone();
-        let disable_custom_kernels = args.disable_custom_kernels.clone();
-        let watermark_gamma = args.watermark_gamma.clone();
-        let watermark_delta = args.watermark_delta.clone();
+        let quantize = args.quantize;
+        let master_port = args.master_port;
+        let disable_custom_kernels = args.disable_custom_kernels;
+        let watermark_gamma = args.watermark_gamma;
+        let watermark_delta = args.watermark_delta;
         thread::spawn(move || {
             shard_manager(
                 model_id,
@@ -559,12 +560,12 @@ fn spawn_shards(
             }
             Ok(ShardStatus::Failed((rank, err))) => {
                 tracing::error!("Shard {} failed to start:\n{}", rank, err);
-                shutdown_shards(shutdown, &shutdown_receiver);
+                shutdown_shards(shutdown, shutdown_receiver);
                 return Err(LauncherError::ShardCannotStart);
             }
             Err(TryRecvError::Disconnected) => {
                 tracing::error!("Shard status channel disconnected");
-                shutdown_shards(shutdown, &shutdown_receiver);
+                shutdown_shards(shutdown, shutdown_receiver);
                 return Err(LauncherError::ShardDisconnected);
             }
         }
@@ -666,7 +667,7 @@ fn spawn_webserver(
                 tracing::error!("{}", err);
             }
 
-            shutdown_shards(shutdown, &shutdown_receiver);
+            shutdown_shards(shutdown, shutdown_receiver);
             return Err(LauncherError::WebserverCannotStart);
         }
     };
@@ -15,8 +15,13 @@ service TextGenerationService {
     rpc Prefill (PrefillRequest) returns (PrefillResponse);
     /// Decode token for a list of prefilled batches
     rpc Decode (DecodeRequest) returns (DecodeResponse);
+    /// Health check
+    rpc Health (HealthRequest) returns (HealthResponse);
 }
 
+message HealthRequest {}
+message HealthResponse {}
+
 /// Empty request
 message InfoRequest {}
 
@@ -173,4 +178,4 @@ message DecodeResponse {
     repeated Generation generations = 1;
     /// Next batch (cached)
     optional Batch batch = 2;
-}
+}
@@ -7,7 +7,7 @@ use tonic::transport::{Channel, Uri};
 use tracing::instrument;
 
 /// Text Generation Inference gRPC client
-#[derive(Clone)]
+#[derive(Debug, Clone)]
 pub struct Client {
     stub: TextGenerationServiceClient<Channel>,
 }
@@ -62,6 +62,14 @@ impl Client {
         Ok(response)
     }
 
+    /// Get model health
+    #[instrument(skip(self))]
+    pub async fn health(&mut self) -> Result<HealthResponse> {
+        let request = tonic::Request::new(HealthRequest {}).inject_context();
+        let response = self.stub.health(request).await?.into_inner();
+        Ok(response)
+    }
+
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
@@ -6,6 +6,7 @@ mod pb;
 mod sharded_client;
 
 pub use client::Client;
+pub use pb::generate::v1::HealthResponse;
 pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
     Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
@@ -1,10 +1,11 @@
 /// Multi shard Client
 use crate::Result;
-use crate::{Batch, Client, Generation, Request, ShardInfo};
+use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;
 
+#[derive(Debug, Clone)]
 /// Text Generation Inference gRPC multi client
 pub struct ShardedClient {
     clients: Vec<Client>,
@@ -48,6 +49,17 @@ impl ShardedClient {
         join_all(futures).await.pop().unwrap()
     }
 
+    /// GRPC health check
+    #[instrument(skip(self))]
+    pub async fn health(&mut self) -> Result<HealthResponse> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| client.health())
+            .collect();
+        join_all(futures).await.pop().unwrap()
+    }
+
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
@@ -0,0 +1,62 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+use text_generation_client::{
+    Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
+};
+
+#[derive(Clone, Debug)]
+pub(crate) struct Health {
+    client: ShardedClient,
+    generation_health: Arc<AtomicBool>,
+}
+
+impl Health {
+    pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self {
+        Self {
+            client,
+            generation_health,
+        }
+    }
+
+    pub(crate) async fn check(&mut self) -> bool {
+        if self.generation_health.load(Ordering::SeqCst) {
+            // Generation is healthy, we only check that the shards are answering gRPC calls
+            self.client.health().await.is_ok()
+        } else {
+            // Generation is unhealthy or have not sent any generation request yet
+
+            // Dummy batch of 1 token and 1 generated token
+            let liveness_request = Request {
+                id: u64::MAX,
+                inputs: "liveness".to_string(),
+                truncate: 10,
+                parameters: Some(NextTokenChooserParameters {
+                    temperature: 1.0,
+                    top_k: 0,
+                    top_p: 1.0,
+                    typical_p: 1.0,
+                    do_sample: false,
+                    seed: 0,
+                    repetition_penalty: 1.0,
+                    watermark: false,
+                }),
+                stopping_parameters: Some(StoppingCriteriaParameters {
+                    max_new_tokens: 1,
+                    stop_sequences: vec![],
+                    ignore_eos_token: false,
+                }),
+            };
+            let batch = Batch {
+                id: u64::MAX,
+                requests: vec![liveness_request],
+                size: 1,
+                max_tokens: 2,
+            };
+            // Skips the queue
+            let value = self.client.prefill(batch).await.is_ok();
+            // Update generation health
+            self.generation_health.store(value, Ordering::SeqCst);
+            value
+        }
+    }
+}
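A minimal usage sketch of the new module above (not part of the commit): the router builds Health once from the ShardedClient and the shared flag that the batching task updates, then calls check() on every probe. ShardedClient::connect_uds and the connection path are assumptions taken from the surrounding codebase; the rest is illustrative.

    use std::sync::atomic::AtomicBool;
    use std::sync::Arc;

    // Illustrative wiring only; `Health` is pub(crate), so this lives inside the router crate.
    async fn wire_health(master_shard_uds_path: String) {
        // Connect to the model shards (assumed helper from the client crate).
        let client = ShardedClient::connect_uds(master_shard_uds_path)
            .await
            .expect("could not connect to the shards");
        // Shared flag: prefill/decode store `true` on success and `false` on failure.
        let generation_health = Arc::new(AtomicBool::new(false));
        let mut health = Health::new(client, generation_health.clone());

        // While the flag is `false` (startup, or after a failed generation), check()
        // sends the one-token dummy prefill batch straight to the shards, skipping
        // the queue. Once the flag is `true`, it only issues the cheap gRPC Health call.
        let healthy = health.check().await;
        println!("healthy: {healthy}");
    }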
@@ -7,7 +7,10 @@ use flume::SendError;
 use futures::future::try_join_all;
 use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
-use std::sync::Arc;
+use std::sync::{
+    atomic::{AtomicBool, Ordering},
+    Arc,
+};
 use text_generation_client::{
     Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
@@ -36,6 +39,7 @@ struct Shared {
 }
 
 impl Infer {
+    #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         client: ShardedClient,
         validation: Validation,
@@ -44,6 +48,7 @@ impl Infer {
         max_waiting_tokens: usize,
         max_concurrent_requests: usize,
         requires_padding: bool,
+        generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(requires_padding);
@@ -59,6 +64,7 @@ impl Infer {
             max_waiting_tokens,
             queue.clone(),
             shared.clone(),
+            generation_health,
         ));
 
         // Inference limit with a semaphore
@@ -240,6 +246,7 @@ async fn batching_task(
     max_waiting_tokens: usize,
     queue: Queue,
     shared: Arc<Shared>,
+    generation_health: Arc<AtomicBool>,
 ) {
     // Infinite loop
     loop {
@@ -252,7 +259,7 @@ async fn batching_task(
         while let Some((mut entries, batch, span)) =
             queue.next_batch(None, max_batch_total_tokens).await
         {
-            let mut cached_batch = prefill(&mut client, batch, &mut entries)
+            let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
                 .instrument(span)
                 .await;
             let mut waiting_tokens = 1;
@@ -301,9 +308,10 @@ async fn batching_task(
                     });
 
                     // Generate one token for this new batch to have the attention past in cache
-                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
-                        .instrument(span)
-                        .await;
+                    let new_cached_batch =
+                        prefill(&mut client, new_batch, &mut new_entries, &generation_health)
+                            .instrument(span)
+                            .await;
                     // Reset waiting counter
                     waiting_tokens = 1;
                     // Extend current batch with the new batch
@@ -327,7 +335,7 @@ async fn batching_task(
                     entry.temp_span = Some(entry_batch_span);
                 });
 
-                cached_batch = decode(&mut client, batches, &mut entries)
+                cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
                     .instrument(next_batch_span)
                     .await;
                 waiting_tokens += 1;
@@ -343,6 +351,7 @@ async fn prefill(
     client: &mut ShardedClient,
     batch: Batch,
     entries: &mut IntMap<u64, Entry>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
     let batch_id = batch.id;
@@ -350,6 +359,8 @@ async fn prefill(
 
     match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
+            // Update health
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
@@ -362,6 +373,8 @@ async fn prefill(
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
+            // Update health
+            generation_health.store(false, Ordering::SeqCst);
             let _ = client.clear_cache(Some(batch_id)).await;
             send_errors(err, entries);
             metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
@@ -375,6 +388,7 @@ async fn decode(
     client: &mut ShardedClient,
     batches: Vec<Batch>,
     entries: &mut IntMap<u64, Entry>,
+    generation_health: &Arc<AtomicBool>,
 ) -> Option<Batch> {
     let start_time = Instant::now();
     let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
@@ -382,6 +396,8 @@ async fn decode(
 
     match client.decode(batches).await {
         Ok((generations, next_batch)) => {
+            // Update health
+            generation_health.store(true, Ordering::SeqCst);
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);
 
@@ -394,6 +410,7 @@ async fn decode(
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
+            generation_health.store(false, Ordering::SeqCst);
             for id in batch_ids {
                 let _ = client.clear_cache(Some(id)).await;
             }
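The generation_health flag threaded through prefill and decode above is the only coordination between the batching task and the health check. A small sketch of that pattern (not part of the commit; names are illustrative):

    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    // Writer side: the batching task records the outcome of every prefill/decode call.
    fn record_generation_result(flag: &Arc<AtomicBool>, ok: bool) {
        flag.store(ok, Ordering::SeqCst);
    }

    // Reader side: the health check trusts the cheap gRPC ping only while the last
    // recorded generation succeeded; otherwise it falls back to a real prefill.
    fn generation_is_healthy(flag: &Arc<AtomicBool>) -> bool {
        flag.load(Ordering::SeqCst)
    }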
@@ -1,3 +1,4 @@
+mod health;
 /// Text Generation Inference Webserver
 mod infer;
 mod queue;
@@ -278,17 +279,21 @@ pub(crate) struct ErrorResponse {
 }
 
 #[cfg(test)]
-mod tests{
+mod tests {
     use std::io::Write;
     use tokenizers::Tokenizer;
 
-    pub(crate) async fn get_tokenizer() -> Tokenizer{
-        if !std::path::Path::new("tokenizer.json").exists(){
-            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap();
-            let mut file = std::fs::File::create("tokenizer.json").unwrap();
+    pub(crate) async fn get_tokenizer() -> Tokenizer {
+        if !std::path::Path::new("tokenizer.json").exists() {
+            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
+                .await
+                .unwrap()
+                .bytes()
+                .await
+                .unwrap();
+            let mut file = std::fs::File::create("tokenizer.json").unwrap();
             file.write_all(&content).unwrap();
         }
         Tokenizer::from_file("tokenizer.json").unwrap()
     }
 }
-
@@ -141,7 +141,6 @@ impl State {
 
     // Get the next batch
     fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
-
         if self.entries.is_empty() {
             return None;
         }
@@ -1,3 +1,4 @@
+use crate::health::Health;
 /// HTTP Server logic
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
@@ -18,6 +19,8 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
@@ -82,36 +85,29 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
     Json(info.0)
 }
 
+#[utoipa::path(
+    get,
+    tag = "Text Generation Inference",
+    path = "/health",
+    responses(
+        (status = 200, description = "Everything is working fine"),
+        (status = 503, description = "Text generation inference is down", body = ErrorResponse,
+        example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
+    )
+)]
+#[instrument(skip(health))]
 /// Health check method
-#[instrument(skip(infer))]
-async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
-    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
-    // be a bit too slow for a health check.
-    // What we should do instead is check if the gRPC channels are still healthy.
-
-    // Send a small inference request
-    infer
-        .generate(GenerateRequest {
-            inputs: "liveness".to_string(),
-            parameters: GenerateParameters {
-                best_of: None,
-                temperature: None,
-                repetition_penalty: None,
-                top_k: None,
-                top_p: None,
-                typical_p: None,
-                do_sample: false,
-                max_new_tokens: 1,
-                return_full_text: None,
-                stop: Vec::new(),
-                truncate: None,
-                watermark: false,
-                details: false,
-                seed: None,
-            },
-        })
-        .await?;
-    Ok(())
+async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
+    match health.check().await {
+        true => Ok(()),
+        false => Err((
+            StatusCode::SERVICE_UNAVAILABLE,
+            Json(ErrorResponse {
+                error: "unhealthy".to_string(),
+                error_type: "healthcheck".to_string(),
+            }),
+        )),
+    }
 }
 
 /// Generate tokens
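A hedged sketch of what a caller sees from the new route (not part of the commit), assuming the router listens on its default port 3000 and that reqwest is available:

    // Probe /health: 200 means the router and shards are responsive, 503 carries
    // the {"error": "unhealthy", "error_type": "healthcheck"} body documented above.
    async fn probe_health() -> Result<(), reqwest::Error> {
        let status = reqwest::get("http://localhost:3000/health").await?.status();
        match status.as_u16() {
            200 => println!("healthy"),
            503 => println!("unhealthy"),
            other => println!("unexpected status: {other}"),
        }
        Ok(())
    }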
@@ -555,6 +551,8 @@ pub async fn run(
         max_input_length,
         max_total_tokens,
     );
+    let generation_health = Arc::new(AtomicBool::new(false));
+    let health_ext = Health::new(client.clone(), generation_health.clone());
     let infer = Infer::new(
         client,
         validation,
@@ -563,6 +561,7 @@ pub async fn run(
         max_waiting_tokens,
         max_concurrent_requests,
         shard_info.requires_padding,
+        generation_health,
     );
 
     // Duration buckets
@@ -657,6 +656,7 @@ pub async fn run(
         // Prometheus metrics route
         .route("/metrics", get(metrics))
         .layer(Extension(info))
+        .layer(Extension(health_ext))
         .layer(Extension(compat_return_full_text))
         .layer(Extension(infer))
         .layer(Extension(prom_handle))
@@ -741,4 +741,3 @@ impl From<InferError> for Event {
             .unwrap()
     }
 }
-
@@ -380,111 +380,154 @@ pub enum ValidationError {
 }
 
 #[cfg(test)]
-mod tests{
+mod tests {
     use super::*;
     use crate::default_parameters;
     use crate::tests::get_tokenizer;
 
     #[tokio::test]
-    async fn test_validation_max_new_tokens(){
+    async fn test_validation_max_new_tokens() {
         let tokenizer = None;
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
 
         let max_new_tokens = 10;
-        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+        match validation
+            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .await
+        {
             Err(ValidationError::MaxNewTokens(1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens")
+            _ => panic!("Unexpected not max new tokens"),
         }
     }
 
     #[tokio::test]
-    async fn test_validation_input_length(){
+    async fn test_validation_input_length() {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
 
         let max_new_tokens = 10;
-        match validation.validate_input("Hello".to_string(), None, max_new_tokens).await{
+        match validation
+            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .await
+        {
             Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens")
+            _ => panic!("Unexpected not max new tokens"),
         }
     }
 
     #[tokio::test]
-    async fn test_validation_best_of_sampling(){
+    async fn test_validation_best_of_sampling() {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
        let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
-        match validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters{
-                best_of: Some(2),
-                do_sample: false,
-                ..default_parameters()
-            }
-        }).await{
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    best_of: Some(2),
+                    do_sample: false,
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
             Err(ValidationError::BestOfSampling) => (),
-            _ => panic!("Unexpected not best of sampling")
+            _ => panic!("Unexpected not best of sampling"),
         }
-
     }
 
     #[tokio::test]
-    async fn test_validation_top_p(){
+    async fn test_validation_top_p() {
         let tokenizer = Some(get_tokenizer().await);
         let max_best_of = 2;
         let max_stop_sequence = 3;
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
-        match validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters{
-                top_p: Some(1.0),
-                ..default_parameters()
-            }
-        }).await{
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_p: Some(1.0),
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
             Err(ValidationError::TopP) => (),
-            _ => panic!("Unexpected top_p")
+            _ => panic!("Unexpected top_p"),
         }
 
-        match validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters{
-                top_p: Some(0.99),
-                max_new_tokens: 1,
-                ..default_parameters()
-            }
-        }).await{
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_p: Some(0.99),
+                    max_new_tokens: 1,
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
             Ok(_) => (),
-            _ => panic!("Unexpected top_p error")
+            _ => panic!("Unexpected top_p error"),
         }
 
-        let valid_request = validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters{
-                top_p: None,
-                max_new_tokens: 1,
-                ..default_parameters()
-            }
-        }).await.unwrap();
+        let valid_request = validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_p: None,
+                    max_new_tokens: 1,
+                    ..default_parameters()
+                },
+            })
+            .await
+            .unwrap();
         // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
         assert_eq!(valid_request.parameters.top_p, 1.0);
-
-
     }
 }
@@ -29,6 +29,11 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     async def Info(self, request, context):
        return self.model.info

+    async def Health(self, request, context):
+        if self.model.device.type == "cuda":
+            torch.zeros((2, 2)).cuda()
+        return generate_pb2.HealthResponse()
+
     async def ServiceDiscovery(self, request, context):
         return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
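A note on the Python handler above: allocating a tiny tensor with torch.zeros((2, 2)).cuda() forces a real CUDA operation, so a wedged or failed GPU context should surface as a gRPC error rather than a false healthy response; on CPU the handler simply returns an empty HealthResponse.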