continue refactoring
This commit is contained in:
parent
abf56b75a4
commit
56b16614de
|
@ -3,21 +3,43 @@ mod v3;
|
|||
use crate::infer::InferStreamResponse;
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::ShardInfo;
|
||||
use serde::Serialize;
|
||||
use std::fmt::Debug;
|
||||
use thiserror::Error;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait Scheduler {
|
||||
pub(crate) trait Backend {
|
||||
fn schedule(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, SchedulerError>>, SchedulerError>;
|
||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, BackendError>>, BackendError>;
|
||||
|
||||
async fn health(&self, current_health: bool) -> bool;
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
pub(crate) struct BackendInfo {
|
||||
/// Mandatory
|
||||
#[schema(example = "cuda")]
|
||||
pub model_device_type: String,
|
||||
#[schema(example = "torch.float16")]
|
||||
pub model_dtype: String,
|
||||
#[schema(example = "1")]
|
||||
pub speculate: usize,
|
||||
|
||||
/// Backend parameters
|
||||
#[schema(example = "1.2")]
|
||||
pub waiting_served_ratio: f32,
|
||||
#[schema(example = "32000")]
|
||||
pub max_batch_total_tokens: u32,
|
||||
#[schema(example = "20")]
|
||||
pub max_waiting_tokens: usize,
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub max_batch_size: Option<usize>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn connect_backend(
|
||||
master_shard_uds_path: String,
|
||||
|
@ -28,8 +50,8 @@ pub(crate) async fn connect_backend(
|
|||
max_batch_total_tokens: Option<u32>,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<(Arc<dyn Scheduler + Send + Sync>, ShardInfo, u32), SchedulerError> {
|
||||
v3::connect_backend(
|
||||
) -> Result<(impl Backend, BackendInfo), BackendError> {
|
||||
let (backend, info) = v3::connect_backend(
|
||||
master_shard_uds_path,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
|
@ -40,15 +62,15 @@ pub(crate) async fn connect_backend(
|
|||
max_batch_size,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| SchedulerError::Startup(Box::new(err)))
|
||||
.map_err(|err| BackendError::Startup(Box::new(err)))?;
|
||||
|
||||
Ok((backend, info))
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum SchedulerError {
|
||||
pub enum BackendError {
|
||||
#[error("Startup error: {0}")]
|
||||
Startup(Box<dyn std::error::Error + Send + Sync>),
|
||||
#[error("Request failed during generation: {0}")]
|
||||
Generation(Box<dyn std::error::Error + Send + Sync>),
|
||||
#[error("Backend error: {0}")]
|
||||
Backend(Box<dyn std::error::Error + Send + Sync>),
|
||||
}
|
|
@ -1,7 +1,7 @@
|
|||
/// Batching and inference logic
|
||||
use crate::infer::schedulers::v2::queue::{Entry, Queue};
|
||||
use crate::infer::backends::v2::queue::{Entry, Queue};
|
||||
use crate::infer::{
|
||||
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
|
||||
Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
|
||||
};
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use crate::{FinishReason, PrefillToken, Token};
|
||||
|
@ -18,14 +18,14 @@ use tokio::time::Instant;
|
|||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
||||
pub(crate) struct SchedulerV2 {
|
||||
pub(crate) struct BackendV2 {
|
||||
/// Request queue
|
||||
queue: Queue,
|
||||
/// Notify batcher on queue appends
|
||||
batching_task_notifier: Arc<Notify>,
|
||||
}
|
||||
|
||||
impl SchedulerV2 {
|
||||
impl BackendV2 {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
|
@ -62,7 +62,7 @@ impl SchedulerV2 {
|
|||
}
|
||||
}
|
||||
|
||||
impl Scheduler for SchedulerV2 {
|
||||
impl Backend for BackendV2 {
|
||||
#[instrument(skip_all)]
|
||||
fn schedule(
|
||||
&self,
|
|
@ -0,0 +1,4 @@
|
|||
mod backend;
|
||||
mod queue;
|
||||
|
||||
pub(crate) use backend::BackendV2;
|
|
@ -1,30 +1,34 @@
|
|||
/// Batching and inference logic
|
||||
use crate::infer::schedulers::v3::queue::{Entry, Queue};
|
||||
use crate::infer::schedulers::SchedulerError;
|
||||
use crate::infer::{GeneratedText, InferStreamResponse, Scheduler};
|
||||
use crate::infer::backends::v3::queue::{Entry, Queue};
|
||||
use crate::infer::backends::BackendError;
|
||||
use crate::infer::{Backend, GeneratedText, InferStreamResponse};
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use crate::{FinishReason, PrefillToken, Token};
|
||||
use async_trait::async_trait;
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
|
||||
use text_generation_client::{ClientError, Health};
|
||||
use text_generation_client::{ClientError, Health, ShardInfo};
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
||||
pub(crate) struct SchedulerV3 {
|
||||
pub(crate) struct BackendV3 {
|
||||
/// Request queue
|
||||
queue: Queue,
|
||||
/// Notify batcher on queue appends
|
||||
batching_task_notifier: Arc<Notify>,
|
||||
/// State
|
||||
state: Arc<State>,
|
||||
}
|
||||
|
||||
struct State {
|
||||
batching_task_notifier: Notify,
|
||||
/// Client, used for health checks to skip the queue
|
||||
client: ShardedClient,
|
||||
}
|
||||
|
||||
impl SchedulerV3 {
|
||||
impl BackendV3 {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
|
@ -33,10 +37,15 @@ impl SchedulerV3 {
|
|||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
requires_padding: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
shard_info: ShardInfo,
|
||||
) -> Self {
|
||||
let ShardInfo {
|
||||
requires_padding,
|
||||
window_size,
|
||||
speculate,
|
||||
..
|
||||
} = shard_info;
|
||||
|
||||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
16,
|
||||
|
@ -44,35 +53,34 @@ impl SchedulerV3 {
|
|||
speculate,
|
||||
max_batch_total_tokens,
|
||||
);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
let batching_task_notifier = Notify::new();
|
||||
let state = Arc::new(State {
|
||||
batching_task_notifier,
|
||||
client,
|
||||
});
|
||||
|
||||
// Spawn batching background task that contains all the inference logic
|
||||
tokio::spawn(batching_task(
|
||||
client.clone(),
|
||||
state.clone(),
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
queue.clone(),
|
||||
batching_task_notifier.clone(),
|
||||
));
|
||||
|
||||
Self {
|
||||
queue,
|
||||
batching_task_notifier,
|
||||
client,
|
||||
}
|
||||
Self { queue, state }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Scheduler for SchedulerV3 {
|
||||
impl Backend for BackendV3 {
|
||||
#[instrument(skip_all)]
|
||||
fn schedule(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, SchedulerError>>, SchedulerError>
|
||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, BackendError>>, BackendError>
|
||||
{
|
||||
// MPSC channel to communicate with the background batching task
|
||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||
|
@ -90,7 +98,7 @@ impl Scheduler for SchedulerV3 {
|
|||
|
||||
// Notify the background task that we have a new entry in the queue that needs
|
||||
// to be batched
|
||||
self.batching_task_notifier.notify_one();
|
||||
self.state.batching_task_notifier.notify_one();
|
||||
|
||||
// Return stream
|
||||
Ok(UnboundedReceiverStream::new(response_rx))
|
||||
|
@ -99,9 +107,9 @@ impl Scheduler for SchedulerV3 {
|
|||
async fn health(&self, current_health: bool) -> bool {
|
||||
if current_health {
|
||||
// Generation is healthy, we only check that the shards can allocate on device
|
||||
self.client.device_health().await
|
||||
self.state.client.device_health().await
|
||||
} else {
|
||||
self.client.model_health().await
|
||||
self.state.client.model_health().await
|
||||
}
|
||||
.is_ok()
|
||||
}
|
||||
|
@ -112,20 +120,21 @@ impl Scheduler for SchedulerV3 {
|
|||
///
|
||||
/// Batches requests and sends them to the inference server
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn batching_task(
|
||||
mut client: ShardedClient,
|
||||
async fn batching_task(
|
||||
state: Arc<State>,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
queue: Queue,
|
||||
notifier: Arc<Notify>,
|
||||
) {
|
||||
let mut client = state.client.clone();
|
||||
|
||||
// Infinite loop
|
||||
loop {
|
||||
// Wait for a notification from the Infer struct
|
||||
notifier.notified().await;
|
||||
state.batching_task_notifier.notified().await;
|
||||
|
||||
// Get the next batch from the queue
|
||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||
|
@ -369,7 +378,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
|
|||
fn send_responses(
|
||||
generation: Generation,
|
||||
entry: &Entry,
|
||||
) -> Result<bool, Box<SendError<Result<InferStreamResponse, SchedulerError>>>> {
|
||||
) -> Result<bool, Box<SendError<Result<InferStreamResponse, BackendError>>>> {
|
||||
// Return directly if the channel is disconnected
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
|
@ -462,7 +471,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
|||
entries.drain().for_each(|(_, entry)| {
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||
let err = SchedulerError::Generation(Box::new(error.clone()));
|
||||
let err = BackendError::Generation(Box::new(error.clone()));
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
|
||||
tracing::error!("{err}");
|
||||
|
|
@ -1,12 +1,11 @@
|
|||
mod backend;
|
||||
mod block_allocator;
|
||||
mod queue;
|
||||
mod scheduler;
|
||||
|
||||
use crate::infer::schedulers::v3::scheduler::SchedulerV3;
|
||||
use crate::infer::schedulers::Scheduler;
|
||||
use std::sync::Arc;
|
||||
use crate::infer::backends::v3::backend::BackendV3;
|
||||
use crate::infer::backends::BackendInfo;
|
||||
use text_generation_client::v3::ShardedClient;
|
||||
use text_generation_client::{ClientError, ShardInfo};
|
||||
use text_generation_client::ClientError;
|
||||
use thiserror::Error;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
|
@ -19,7 +18,7 @@ pub(crate) async fn connect_backend(
|
|||
max_batch_total_tokens: Option<u32>,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<(Arc<dyn Scheduler + Send + Sync>, ShardInfo, u32), V3Error> {
|
||||
) -> Result<(BackendV3, BackendInfo), V3Error> {
|
||||
// Helper function
|
||||
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
||||
match max_supported_batch_total_tokens {
|
||||
|
@ -77,21 +76,31 @@ pub(crate) async fn connect_backend(
|
|||
.await
|
||||
.map_err(V3Error::Warmup)?,
|
||||
)?;
|
||||
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||
|
||||
let scheduler = Arc::new(SchedulerV3::new(
|
||||
let backend_info = BackendInfo {
|
||||
waiting_served_ratio,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
model_device_type: shard_info.device_type.clone(),
|
||||
model_dtype: shard_info.dtype.clone(),
|
||||
speculate: shard_info.speculate as usize,
|
||||
};
|
||||
|
||||
let backend = BackendV3::new(
|
||||
sharded_client,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
shard_info.requires_padding,
|
||||
shard_info.window_size,
|
||||
shard_info.speculate,
|
||||
));
|
||||
tracing::info!("Using scheduler V3");
|
||||
shard_info,
|
||||
);
|
||||
|
||||
Ok((scheduler, shard_info, max_batch_total_tokens))
|
||||
tracing::info!("Using backend V3");
|
||||
|
||||
Ok((backend, backend_info))
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
|
@ -1,5 +1,5 @@
|
|||
use crate::infer::schedulers::v3::block_allocator::{BlockAllocation, BlockAllocator};
|
||||
use crate::infer::schedulers::SchedulerError;
|
||||
use crate::infer::backends::v3::block_allocator::{BlockAllocation, BlockAllocator};
|
||||
use crate::infer::backends::BackendError;
|
||||
use crate::infer::InferStreamResponse;
|
||||
use crate::validation::{
|
||||
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
||||
|
@ -22,7 +22,7 @@ pub(crate) struct Entry {
|
|||
/// Request
|
||||
pub request: ValidGenerateRequest,
|
||||
/// Response sender to communicate between the Infer struct and the batching_task
|
||||
pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, SchedulerError>>,
|
||||
pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, BackendError>>,
|
||||
/// Span that will live as long as entry
|
||||
pub span: Span,
|
||||
/// Temporary span used as a guard when logging inference, wait times...
|
||||
|
@ -463,7 +463,7 @@ mod tests {
|
|||
|
||||
fn default_entry() -> (
|
||||
Entry,
|
||||
mpsc::UnboundedReceiver<Result<InferStreamResponse, SchedulerError>>,
|
||||
mpsc::UnboundedReceiver<Result<InferStreamResponse, BackendError>>,
|
||||
) {
|
||||
let (response_tx, receiver_tx) = mpsc::unbounded_channel();
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
pub(crate) mod backends;
|
||||
mod chat_template;
|
||||
pub(crate) mod schedulers;
|
||||
mod tool_grammar;
|
||||
|
||||
pub(crate) use tool_grammar::ToolGrammar;
|
||||
|
@ -11,11 +11,11 @@ use crate::{
|
|||
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
||||
Message, PrefillToken, Token,
|
||||
};
|
||||
pub(crate) use backends::{Backend, BackendInfo};
|
||||
use futures::future::try_join_all;
|
||||
use minijinja::ErrorKind;
|
||||
pub(crate) use schedulers::Scheduler;
|
||||
|
||||
use crate::infer::schedulers::SchedulerError;
|
||||
use crate::infer::backends::BackendError;
|
||||
use async_stream::stream;
|
||||
use futures::Stream;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
@ -31,8 +31,8 @@ use tracing::instrument;
|
|||
pub struct Infer {
|
||||
/// Validation
|
||||
validation: Validation,
|
||||
/// Request scheduler
|
||||
scheduler: Arc<dyn Scheduler + Send + Sync>,
|
||||
/// Request backend
|
||||
backend: Arc<dyn Backend + Send + Sync>,
|
||||
/// Chat template
|
||||
chat_template: Option<ChatTemplate>,
|
||||
/// Inference limit
|
||||
|
@ -44,7 +44,7 @@ pub struct Infer {
|
|||
impl Infer {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
scheduler: Arc<dyn Scheduler + Send + Sync>,
|
||||
backend: impl Backend + Send + Sync + 'static,
|
||||
validation: Validation,
|
||||
max_concurrent_requests: usize,
|
||||
tokenizer_config: HubTokenizerConfig,
|
||||
|
@ -70,7 +70,7 @@ impl Infer {
|
|||
|
||||
Self {
|
||||
validation,
|
||||
scheduler,
|
||||
backend: Arc::new(backend),
|
||||
chat_template,
|
||||
limit_concurrent_requests: semaphore,
|
||||
backend_health,
|
||||
|
@ -110,9 +110,9 @@ impl Infer {
|
|||
|
||||
let input_length = valid_request.input_length;
|
||||
let mut generation_stream = self
|
||||
.scheduler
|
||||
.backend
|
||||
.schedule(valid_request)
|
||||
.map_err(InferError::Scheduler)?;
|
||||
.map_err(InferError::Backend)?;
|
||||
|
||||
let stream = stream! {
|
||||
while let Some(generation) = generation_stream.next().await {
|
||||
|
@ -280,7 +280,7 @@ impl Infer {
|
|||
#[instrument(skip(self))]
|
||||
pub(crate) async fn health(&self) -> bool {
|
||||
let health = self
|
||||
.scheduler
|
||||
.backend
|
||||
.health(self.backend_health.load(Ordering::SeqCst))
|
||||
.await;
|
||||
self.backend_health.store(health, Ordering::SeqCst);
|
||||
|
@ -332,9 +332,9 @@ pub(crate) struct InferResponse {
|
|||
#[derive(Debug, Error)]
|
||||
pub enum InferError {
|
||||
#[error("Request failed during scheduling: {0}")]
|
||||
Scheduler(SchedulerError),
|
||||
Backend(BackendError),
|
||||
#[error("Request failed during generation: {0}")]
|
||||
GenerationError(SchedulerError),
|
||||
GenerationError(BackendError),
|
||||
#[error("Model is overloaded")]
|
||||
Overloaded(#[from] TryAcquireError),
|
||||
#[error("Input validation error: {0}")]
|
||||
|
@ -350,7 +350,7 @@ pub enum InferError {
|
|||
impl InferError {
|
||||
pub(crate) fn error_type(&self) -> &str {
|
||||
match self {
|
||||
InferError::Scheduler(_) => "scheduler",
|
||||
InferError::Backend(_) => "backend",
|
||||
InferError::GenerationError(_) => "generation",
|
||||
InferError::Overloaded(_) => "overloaded",
|
||||
InferError::ValidationError(_) => "validation",
|
||||
|
|
|
@ -1,4 +0,0 @@
|
|||
mod queue;
|
||||
mod scheduler;
|
||||
|
||||
pub(crate) use scheduler::SchedulerV2;
|
|
@ -7,6 +7,7 @@ mod validation;
|
|||
#[cfg(feature = "kserve")]
|
||||
mod kserve;
|
||||
|
||||
use crate::infer::BackendInfo;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::warn;
|
||||
use utoipa::ToSchema;
|
||||
|
@ -135,13 +136,9 @@ pub struct Info {
|
|||
pub model_id: String,
|
||||
#[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
|
||||
pub model_sha: Option<String>,
|
||||
#[schema(example = "torch.float16")]
|
||||
pub model_dtype: String,
|
||||
#[schema(example = "cuda")]
|
||||
pub model_device_type: String,
|
||||
#[schema(nullable = true, example = "text-generation")]
|
||||
pub model_pipeline_tag: Option<String>,
|
||||
/// Router Parameters
|
||||
/// Shared Parameters
|
||||
#[schema(example = "128")]
|
||||
pub max_concurrent_requests: usize,
|
||||
#[schema(example = "2")]
|
||||
|
@ -152,14 +149,6 @@ pub struct Info {
|
|||
pub max_input_tokens: usize,
|
||||
#[schema(example = "2048")]
|
||||
pub max_total_tokens: usize,
|
||||
#[schema(example = "1.2")]
|
||||
pub waiting_served_ratio: f32,
|
||||
#[schema(example = "32000")]
|
||||
pub max_batch_total_tokens: u32,
|
||||
#[schema(example = "20")]
|
||||
pub max_waiting_tokens: usize,
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub max_batch_size: Option<usize>,
|
||||
#[schema(example = "2")]
|
||||
pub validation_workers: usize,
|
||||
#[schema(example = "32")]
|
||||
|
@ -173,6 +162,9 @@ pub struct Info {
|
|||
pub sha: Option<&'static str>,
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub docker_label: Option<&'static str>,
|
||||
/// Backend parameters
|
||||
#[serde(flatten)]
|
||||
backend_info: BackendInfo,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
/// HTTP Server logic
|
||||
use crate::config::Config;
|
||||
use crate::infer::schedulers::{connect_backend, SchedulerError};
|
||||
use crate::infer::Scheduler;
|
||||
use crate::infer::backends::{connect_backend, BackendError};
|
||||
use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
|
||||
#[cfg(feature = "kserve")]
|
||||
use crate::kserve::{
|
||||
|
@ -38,8 +37,6 @@ use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
|||
use serde_json::Value;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::ShardInfo;
|
||||
use thiserror::Error;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::select;
|
||||
|
@ -1494,11 +1491,7 @@ pub async fn run(
|
|||
// Create state
|
||||
|
||||
// Open connection, get model info and warmup
|
||||
let (scheduler, shard_info, max_batch_total_tokens): (
|
||||
Arc<dyn Scheduler + Send + Sync>,
|
||||
ShardInfo,
|
||||
u32,
|
||||
) = connect_backend(
|
||||
let (backend, backend_info) = connect_backend(
|
||||
master_shard_uds_path,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
|
@ -1509,8 +1502,7 @@ pub async fn run(
|
|||
max_batch_size,
|
||||
)
|
||||
.await
|
||||
.map_err(WebServerError::Scheduler)?;
|
||||
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||
.map_err(WebServerError::Backend)?;
|
||||
|
||||
let validation = Validation::new(
|
||||
validation_workers,
|
||||
|
@ -1525,7 +1517,7 @@ pub async fn run(
|
|||
);
|
||||
|
||||
let infer = Infer::new(
|
||||
scheduler,
|
||||
backend,
|
||||
validation,
|
||||
max_concurrent_requests,
|
||||
tokenizer_config,
|
||||
|
@ -1563,7 +1555,7 @@ pub async fn run(
|
|||
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
|
||||
// Speculated tokens buckets
|
||||
let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
|
||||
let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();
|
||||
let skipped_buckets: Vec<f64> = (0..backend_info.speculate + 1).map(|x| x as f64).collect();
|
||||
|
||||
// Prometheus handler
|
||||
let builder = PrometheusBuilder::new()
|
||||
|
@ -1592,20 +1584,15 @@ pub async fn run(
|
|||
|
||||
// Endpoint info
|
||||
let info = Info {
|
||||
backend_info,
|
||||
model_id: model_info.model_id,
|
||||
model_sha: model_info.sha,
|
||||
model_dtype: shard_info.dtype,
|
||||
model_device_type: shard_info.device_type,
|
||||
model_pipeline_tag: model_info.pipeline_tag,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
validation_workers,
|
||||
max_client_batch_size,
|
||||
router: env!("CARGO_PKG_NAME"),
|
||||
|
@ -1814,7 +1801,7 @@ impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
|
|||
InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||
InferError::Scheduler(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
InferError::Backend(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
};
|
||||
|
||||
(
|
||||
|
@ -1840,8 +1827,8 @@ impl From<InferError> for Event {
|
|||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum WebServerError {
|
||||
#[error("Scheduler error: {0}")]
|
||||
Scheduler(#[from] SchedulerError),
|
||||
#[error("Backend error: {0}")]
|
||||
Backend(#[from] BackendError),
|
||||
#[error("Axum error: {0}")]
|
||||
Axum(#[from] axum::BoxError),
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue