parent
9263817c71
commit
10e6f29295
|
@ -4175,7 +4175,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "text-generation-backends-trtllm"
|
||||
version = "2.2.1-dev0"
|
||||
version = "2.3.1-dev0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
|
@ -4198,7 +4198,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "text-generation-benchmark"
|
||||
version = "2.2.1-dev0"
|
||||
version = "2.3.1-dev0"
|
||||
dependencies = [
|
||||
"average",
|
||||
"clap 4.5.17",
|
||||
|
@ -4219,7 +4219,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "text-generation-client"
|
||||
version = "2.2.1-dev0"
|
||||
version = "2.3.1-dev0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"base64 0.22.1",
|
||||
|
@ -4237,7 +4237,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "text-generation-launcher"
|
||||
version = "2.2.1-dev0"
|
||||
version = "2.3.1-dev0"
|
||||
dependencies = [
|
||||
"clap 4.5.17",
|
||||
"ctrlc",
|
||||
|
@ -4256,7 +4256,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "text-generation-router"
|
||||
version = "2.2.1-dev0"
|
||||
version = "2.3.1-dev0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
|
@ -4303,9 +4303,58 @@ dependencies = [
|
|||
"vergen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-router-v2"
|
||||
version = "2.3.1-dev0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
"axum 0.7.5",
|
||||
"axum-tracing-opentelemetry",
|
||||
"base64 0.22.1",
|
||||
"clap 4.5.17",
|
||||
"futures",
|
||||
"futures-util",
|
||||
"grpc-metadata",
|
||||
"hf-hub",
|
||||
"image",
|
||||
"init-tracing-opentelemetry",
|
||||
"jsonschema",
|
||||
"metrics",
|
||||
"metrics-exporter-prometheus",
|
||||
"minijinja",
|
||||
"minijinja-contrib",
|
||||
"nohash-hasher",
|
||||
"once_cell",
|
||||
"opentelemetry 0.20.0",
|
||||
"opentelemetry-otlp",
|
||||
"prost 0.12.6",
|
||||
"prost-build",
|
||||
"rand",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"slotmap",
|
||||
"text-generation-router",
|
||||
"thiserror",
|
||||
"tokenizers 0.20.0",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tonic 0.10.2",
|
||||
"tonic-build",
|
||||
"tower",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
"tracing-opentelemetry 0.21.0",
|
||||
"tracing-subscriber",
|
||||
"utoipa",
|
||||
"utoipa-swagger-ui",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-router-v3"
|
||||
version = "2.2.1-dev0"
|
||||
version = "2.3.1-dev0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
|
|
|
@ -1,19 +1,19 @@
|
|||
[workspace]
|
||||
members = [
|
||||
"benchmark",
|
||||
"backends/v2",
|
||||
"backends/v3",
|
||||
"backends/grpc-metadata",
|
||||
"backends/trtllm",
|
||||
"backends/client",
|
||||
"launcher",
|
||||
"router"
|
||||
]
|
||||
default-members = [
|
||||
"benchmark",
|
||||
"backends/v2",
|
||||
"backends/v3",
|
||||
"backends/grpc-metadata",
|
||||
# "backends/trtllm",
|
||||
"backends/client",
|
||||
"launcher",
|
||||
"router"
|
||||
]
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
[package]
|
||||
name = "text-generation-router-v2"
|
||||
description = "Text Generation Webserver"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "text-generation-router"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.74"
|
||||
async-stream = "0.3.5"
|
||||
axum = { version = "0.7", features = ["json"] }
|
||||
axum-tracing-opentelemetry = "0.16"
|
||||
text-generation-router = { path = "../../router" }
|
||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||
grpc-metadata = { path = "../grpc-metadata" }
|
||||
futures = "0.3.28"
|
||||
hf-hub = { workspace = true }
|
||||
jsonschema = { version = "0.17.1", features = ["draft202012"] }
|
||||
metrics = { workspace = true }
|
||||
metrics-exporter-prometheus = { workspace = true }
|
||||
nohash-hasher = "0.2.0"
|
||||
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
|
||||
opentelemetry-otlp = "0.13.0"
|
||||
rand = "0.8.5"
|
||||
reqwest = { version = "0.11.20", features = [] }
|
||||
serde = "1.0.188"
|
||||
serde_json = "1.0.107"
|
||||
slotmap = "1.0.7"
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { workspace = true }
|
||||
tokio = { version = "1.32.0", features = [
|
||||
"rt",
|
||||
"rt-multi-thread",
|
||||
"parking_lot",
|
||||
"signal",
|
||||
"sync",
|
||||
] }
|
||||
tokio-stream = "0.1.14"
|
||||
tower-http = { version = "0.5.1", features = ["cors"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-opentelemetry = "0.21.0"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||
utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
||||
"opentelemetry-otlp",
|
||||
] }
|
||||
minijinja = { workspace = true }
|
||||
minijinja-contrib = { workspace = true }
|
||||
futures-util = "0.3.30"
|
||||
regex = "1.10.3"
|
||||
once_cell = "1.19.0"
|
||||
image = "0.25.1"
|
||||
base64 = { workspace = true }
|
||||
prost = "^0.12"
|
||||
tonic = "^0.10"
|
||||
tower = "^0.4"
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "0.10.1"
|
||||
prost-build = "0.12.1"
|
||||
|
||||
[features]
|
||||
default = ["ngrok"]
|
||||
ngrok = ["text-generation-router/ngrok"]
|
||||
google = ["text-generation-router/google"]
|
||||
kserve = ["text-generation-router/kserve"]
|
|
@ -0,0 +1,19 @@
|
|||
use std::fs;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("cargo:rerun-if-changed=../../proto/");
|
||||
|
||||
fs::create_dir_all("src/client/pb").unwrap_or(());
|
||||
let mut config = prost_build::Config::new();
|
||||
config.protoc_arg("--experimental_allow_proto3_optional");
|
||||
|
||||
tonic_build::configure()
|
||||
.build_client(true)
|
||||
.build_server(false)
|
||||
.out_dir("src/client/pb")
|
||||
.include_file("mod.rs")
|
||||
.compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
|
||||
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,506 @@
|
|||
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
|
||||
/// Batching and inference logic
|
||||
use crate::queue::{Entry, Queue};
|
||||
use async_trait::async_trait;
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::Arc;
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::ValidGenerateRequest;
|
||||
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
||||
pub struct BackendV2 {
|
||||
/// Request queue
|
||||
queue: Queue,
|
||||
/// Notify batcher on queue appends
|
||||
batching_task_notifier: Arc<Notify>,
|
||||
/// Client clone, used for health checks to skip the queue
|
||||
client: ShardedClient,
|
||||
}
|
||||
|
||||
impl BackendV2 {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
requires_padding: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
) -> Self {
|
||||
// Infer shared state
|
||||
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
|
||||
attention
|
||||
.parse()
|
||||
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
|
||||
} else {
|
||||
Attention::Paged
|
||||
};
|
||||
let block_size = if attention == Attention::FlashDecoding {
|
||||
256
|
||||
} else {
|
||||
16
|
||||
};
|
||||
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
|
||||
// Spawn batching background task that contains all the inference logic
|
||||
tokio::spawn(batching_task(
|
||||
client.clone(),
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
queue.clone(),
|
||||
batching_task_notifier.clone(),
|
||||
));
|
||||
|
||||
Self {
|
||||
queue,
|
||||
batching_task_notifier,
|
||||
client,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Backend for BackendV2 {
|
||||
#[instrument(skip_all)]
|
||||
fn schedule(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||
// MPSC channel to communicate with the background batching task
|
||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||
|
||||
// Append the request to the queue
|
||||
self.queue.append(Entry {
|
||||
request,
|
||||
response_tx,
|
||||
span: Span::current(),
|
||||
temp_span: None,
|
||||
queue_time: Instant::now(),
|
||||
batch_time: None,
|
||||
});
|
||||
|
||||
// Notify the background task that we have a new entry in the queue that needs
|
||||
// to be batched
|
||||
self.batching_task_notifier.notify_one();
|
||||
|
||||
// Return stream
|
||||
Ok(UnboundedReceiverStream::new(response_rx))
|
||||
}
|
||||
|
||||
async fn health(&self, current_health: bool) -> bool {
|
||||
if current_health {
|
||||
// Generation is healthy, we only check that the shards can allocate on device
|
||||
self.client.device_health().await
|
||||
} else {
|
||||
self.client.model_health().await
|
||||
}
|
||||
.is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
/// Batching logic
|
||||
/// Will be launched in a background Tokio task
|
||||
///
|
||||
/// Batches requests and sends them to the inference server
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn batching_task(
|
||||
mut client: ShardedClient,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
queue: Queue,
|
||||
notifier: Arc<Notify>,
|
||||
) {
|
||||
// Infinite loop
|
||||
loop {
|
||||
// Wait for a notification from the Infer struct
|
||||
notifier.notified().await;
|
||||
|
||||
// Get the next batch from the queue
|
||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||
// waiting in the queue
|
||||
while let Some((mut entries, batch, span)) = queue
|
||||
.next_batch(
|
||||
None,
|
||||
max_batch_size,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
let mut waiting_tokens = 1;
|
||||
|
||||
// We loop until we do not receive any cached batch from the inference server (== until
|
||||
// all requests have met their stopping criteria)
|
||||
while let Some(batch) = cached_batch {
|
||||
// Get current batch info
|
||||
let batch_size = batch.size;
|
||||
let batch_max_tokens = batch.max_tokens;
|
||||
let mut batches = vec![batch];
|
||||
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
|
||||
|
||||
let min_size = if waiting_tokens >= max_waiting_tokens {
|
||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||
// to add a new batch even though its size might be small
|
||||
None
|
||||
} else {
|
||||
// Minimum batch size
|
||||
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||
};
|
||||
|
||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||
let max_size =
|
||||
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
|
||||
// Try to get a new batch
|
||||
if let Some((mut new_entries, new_batch, span)) = queue
|
||||
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
||||
.await
|
||||
{
|
||||
// Tracking metrics
|
||||
if min_size.is_some() {
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
|
||||
.increment(1);
|
||||
} else {
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to add the info that this entry is waiting
|
||||
// because a new batch is being computed
|
||||
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
|
||||
// Add relationships
|
||||
span.follows_from(&entry_waiting_span);
|
||||
entry_waiting_span.follows_from(&span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_waiting_span);
|
||||
});
|
||||
|
||||
// Generate one token for this new batch to have the attention past in cache
|
||||
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
// Reset waiting counter
|
||||
waiting_tokens = 1;
|
||||
// Extend current batch with the new batch
|
||||
if let Some(new_cached_batch) = new_cached_batch {
|
||||
entries.extend(new_entries);
|
||||
batches.push(new_cached_batch);
|
||||
}
|
||||
}
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_size = entries.len();
|
||||
let next_batch_span =
|
||||
info_span!(parent: None, "batch", batch_size = next_batch_size);
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
// Add relationships
|
||||
next_batch_span.follows_from(&entry_batch_span);
|
||||
entry_batch_span.follows_from(&next_batch_span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
});
|
||||
|
||||
cached_batch = decode(&mut client, batches, &mut entries)
|
||||
.instrument(next_batch_span)
|
||||
.await;
|
||||
waiting_tokens += 1;
|
||||
}
|
||||
metrics::gauge!("tgi_batch_current_size").set(0.0);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn prefill(
|
||||
client: &mut ShardedClient,
|
||||
batch: Batch,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_id = batch.id;
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
|
||||
|
||||
match client.prefill(batch).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
let _ = client.clear_cache(Some(batch_id)).await;
|
||||
send_errors(err, entries);
|
||||
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn decode(
|
||||
client: &mut ShardedClient,
|
||||
batches: Vec<CachedBatch>,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
|
||||
|
||||
match client.decode(batches).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
if let Some(concat_duration) = timings.concat {
|
||||
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
|
||||
.record(concat_duration.as_secs_f64());
|
||||
}
|
||||
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
for id in batch_ids {
|
||||
let _ = client.clear_cache(Some(id)).await;
|
||||
}
|
||||
send_errors(err, entries);
|
||||
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Filter a `batch` and remove all requests not present in `entries`
|
||||
#[instrument(skip_all)]
|
||||
async fn filter_batch(
|
||||
client: &mut ShardedClient,
|
||||
next_batch: Option<CachedBatch>,
|
||||
entries: &IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let mut batch = next_batch?;
|
||||
|
||||
// No need to filter
|
||||
if batch.size as usize == entries.len() {
|
||||
return Some(batch);
|
||||
}
|
||||
|
||||
let id = batch.id;
|
||||
|
||||
// Retain only requests that are still in entries
|
||||
batch.request_ids.retain(|id| entries.contains_key(id));
|
||||
|
||||
if batch.request_ids.is_empty() {
|
||||
// All requests have been filtered out
|
||||
// Next batch is now empty
|
||||
// Clear it from the Python shards cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.clear_cache(Some(id)).await.unwrap();
|
||||
None
|
||||
} else {
|
||||
// Filter Python shard cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.filter_batch(id, batch.request_ids).await.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||
/// and filter entries
|
||||
#[instrument(skip_all)]
|
||||
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
||||
generations.into_iter().for_each(|generation| {
|
||||
let id = generation.request_id;
|
||||
// Get entry
|
||||
// We can `expect` here as the request id should always be in the entries
|
||||
let entry = entries
|
||||
.get(&id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
|
||||
// Send generation responses back to the infer task
|
||||
// If the receive an error from the Flume channel, it means that the client dropped the
|
||||
// request and we need to stop generating hence why we unwrap_or(true)
|
||||
let stopped = send_responses(generation, entry).inspect_err(|_err| {
|
||||
tracing::error!("Entry response channel error.");
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
}).unwrap_or(true);
|
||||
if stopped {
|
||||
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Send responses through the `entry` response channel
|
||||
fn send_responses(
|
||||
generation: Generation,
|
||||
entry: &Entry,
|
||||
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
|
||||
// Return directly if the channel is disconnected
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
let mut stopped = false;
|
||||
|
||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||
// Create Token objects
|
||||
// We do that here instead of in the Python code as Rust for loops are faster
|
||||
let prefill_tokens = prefill_tokens
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(prefill_tokens.logprobs)
|
||||
.zip(prefill_tokens.texts)
|
||||
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
||||
.collect();
|
||||
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
|
||||
}
|
||||
|
||||
// Create last Token
|
||||
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
|
||||
let n = tokens_.ids.len();
|
||||
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
|
||||
let mut iterator = tokens_
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(tokens_.logprobs)
|
||||
.zip(tokens_.texts)
|
||||
.zip(tokens_.is_special)
|
||||
.enumerate()
|
||||
.peekable();
|
||||
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
|
||||
let token = Token {
|
||||
id,
|
||||
text,
|
||||
logprob,
|
||||
special,
|
||||
};
|
||||
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
|
||||
top_tokens_
|
||||
.ids
|
||||
.iter()
|
||||
.zip(top_tokens_.logprobs.iter())
|
||||
.zip(top_tokens_.texts.iter())
|
||||
.zip(top_tokens_.is_special.iter())
|
||||
.map(|(((&id, &logprob), text), &special)| Token {
|
||||
id,
|
||||
text: text.to_string(),
|
||||
logprob,
|
||||
special,
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
match (&generation.generated_text, iterator.peek()) {
|
||||
(Some(generated_text), None) => {
|
||||
// Generation has ended
|
||||
stopped = true;
|
||||
// Send message
|
||||
entry.response_tx.send(Ok(InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens,
|
||||
generated_text: GeneratedText::from(generated_text.clone()),
|
||||
queued: entry.queue_time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
}))?;
|
||||
}
|
||||
_ => {
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(stopped)
|
||||
}
|
||||
|
||||
/// Send errors to Infer for all `entries`
|
||||
#[instrument(skip_all)]
|
||||
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||
entries.drain().for_each(|(_, entry)| {
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||
let err = InferError::GenerationError(error.to_string());
|
||||
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
entry
|
||||
.response_tx
|
||||
.send(Err(err))
|
||||
.unwrap_or(());
|
||||
});
|
||||
}
|
||||
|
||||
impl From<crate::client::GeneratedText> for GeneratedText {
|
||||
fn from(value: crate::client::GeneratedText) -> Self {
|
||||
let v2_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
|
||||
let finish_reason = match v2_finish_reason {
|
||||
crate::client::FinishReason::Length => FinishReason::Length,
|
||||
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||
crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||
};
|
||||
|
||||
Self {
|
||||
text: value.text,
|
||||
generated_tokens: value.generated_tokens,
|
||||
finish_reason,
|
||||
seed: value.seed,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,257 @@
|
|||
/// Single shard Client
|
||||
use crate::client::pb;
|
||||
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
||||
use grpc_metadata::InjectTelemetryContext;
|
||||
use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
|
||||
use pb::generate::v2::*;
|
||||
use std::cmp::min;
|
||||
use std::time::Duration;
|
||||
use tonic::transport::{Channel, Uri};
|
||||
use tracing::instrument;
|
||||
|
||||
/// Text Generation Inference gRPC client
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Client {
|
||||
stub: TextGenerationServiceClient<Channel>,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
/// Returns a client connected to the given url
|
||||
#[allow(dead_code)]
|
||||
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||
let channel = Channel::builder(uri).connect().await?;
|
||||
|
||||
Ok(Self {
|
||||
stub: TextGenerationServiceClient::new(channel),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given unix socket
|
||||
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||
let channel = Channel::from_shared("http://[::]:50051".to_string())
|
||||
.unwrap()
|
||||
.connect_with_connector(tower::service_fn(move |_: Uri| {
|
||||
tokio::net::UnixStream::connect(path.clone())
|
||||
}))
|
||||
.await?;
|
||||
|
||||
Ok(Self {
|
||||
stub: TextGenerationServiceClient::new(channel),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a list of uris or unix sockets of all shards
|
||||
#[instrument(skip(self))]
|
||||
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
|
||||
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
|
||||
let response = self.stub.service_discovery(request).await.map_err(|_| {
|
||||
ClientError::Connection("Server does not support v2 interface".to_string())
|
||||
})?;
|
||||
let urls = response
|
||||
.into_inner()
|
||||
.urls
|
||||
.into_iter()
|
||||
// Remove unix socket prefix
|
||||
.map(|url| match url.strip_prefix("unix://") {
|
||||
None => url,
|
||||
Some(stripped_url) => stripped_url.to_string(),
|
||||
})
|
||||
.collect();
|
||||
Ok(urls)
|
||||
}
|
||||
|
||||
/// Get model info
|
||||
#[instrument(skip(self))]
|
||||
pub async fn info(&mut self) -> Result<InfoResponse> {
|
||||
let request = tonic::Request::new(InfoRequest {}).inject_context();
|
||||
let response = self.stub.info(request).await?.into_inner();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Get model health
|
||||
#[instrument(skip(self))]
|
||||
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||
let request = tonic::Request::new(HealthRequest {}).inject_context();
|
||||
let response = self.stub.health(request).await?.into_inner();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Clear the past generations cache
|
||||
#[instrument(skip(self))]
|
||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
|
||||
self.stub.clear_cache(request).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Filter a cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn filter_batch(
|
||||
&mut self,
|
||||
batch_id: u64,
|
||||
request_ids: Vec<u64>,
|
||||
) -> Result<Option<CachedBatch>> {
|
||||
let request = tonic::Request::new(FilterBatchRequest {
|
||||
batch_id,
|
||||
request_ids,
|
||||
})
|
||||
.inject_context();
|
||||
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
|
||||
Ok(filtered_batch.batch)
|
||||
}
|
||||
|
||||
/// Warmup on a max size batch
|
||||
///
|
||||
/// Returns the maximum amount of tokens supported by the hardware
|
||||
#[instrument(skip_all)]
|
||||
pub async fn warmup(
|
||||
&mut self,
|
||||
max_input_length: u32,
|
||||
max_prefill_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<Option<u32>> {
|
||||
let mut n_tokens = 0;
|
||||
let mut requests = Vec::new();
|
||||
// Create requests
|
||||
while n_tokens < max_prefill_tokens {
|
||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||
|
||||
let mut inputs = String::new();
|
||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||
if n_tokens == 0 {
|
||||
// 1 request is enough to test vision heads.
|
||||
// Sending images on other queries messes up easily with truncation.
|
||||
inputs.push_str(&format!(
|
||||
"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
|
||||
));
|
||||
}
|
||||
|
||||
requests.push(Request {
|
||||
id: 0,
|
||||
inputs,
|
||||
// We truncate the input on the server side to be sure that it has the correct size
|
||||
truncate,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 0.9,
|
||||
top_k: 10,
|
||||
top_p: 0.9,
|
||||
typical_p: 0.9,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 1.2,
|
||||
frequency_penalty: 0.1,
|
||||
watermark: true,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: max_total_tokens - truncate,
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: true,
|
||||
}),
|
||||
prefill_logprobs: true,
|
||||
top_n_tokens: 20,
|
||||
});
|
||||
n_tokens += max_input_length;
|
||||
|
||||
// Check max_batch_size
|
||||
if Some(requests.len()) == max_batch_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let batch = Batch {
|
||||
id: 0,
|
||||
size: requests.len() as u32,
|
||||
requests,
|
||||
max_tokens: 0,
|
||||
};
|
||||
|
||||
let request = tonic::Request::new(WarmupRequest {
|
||||
batch: Some(batch),
|
||||
max_input_length,
|
||||
max_prefill_tokens,
|
||||
max_total_tokens,
|
||||
})
|
||||
.inject_context();
|
||||
let response = self.stub.warmup(request).await?.into_inner();
|
||||
Ok(response.max_supported_total_tokens)
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given batch
|
||||
///
|
||||
/// Returns Generation for each request in batch
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
|
||||
pub async fn prefill(
|
||||
&mut self,
|
||||
batch: Batch,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
|
||||
let response = self.stub.prefill(request).await?.into_inner();
|
||||
Ok((
|
||||
response.generations,
|
||||
response.batch,
|
||||
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
|
||||
))
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given cached batches
|
||||
///
|
||||
/// Returns Generation for each request in batches
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
|
||||
pub async fn decode(
|
||||
&mut self,
|
||||
batches: Vec<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
|
||||
let response = self.stub.decode(request).await?.into_inner();
|
||||
Ok((
|
||||
response.generations,
|
||||
response.batch,
|
||||
DecodeTimings::new(
|
||||
response.concat_ns,
|
||||
response.forward_ns,
|
||||
response.decode_ns,
|
||||
response.total_ns,
|
||||
),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PrefillTimings {
|
||||
pub forward: Duration,
|
||||
pub decode: Duration,
|
||||
pub total: Duration,
|
||||
}
|
||||
|
||||
impl PrefillTimings {
|
||||
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||
Self {
|
||||
forward: Duration::from_nanos(forward_ns),
|
||||
decode: Duration::from_nanos(decode_ns),
|
||||
total: Duration::from_nanos(total_ns),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DecodeTimings {
|
||||
pub concat: Option<Duration>,
|
||||
pub forward: Duration,
|
||||
pub decode: Duration,
|
||||
pub total: Duration,
|
||||
}
|
||||
|
||||
impl DecodeTimings {
|
||||
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||
Self {
|
||||
concat: concat_ns.map(Duration::from_nanos),
|
||||
forward: Duration::from_nanos(forward_ns),
|
||||
decode: Duration::from_nanos(decode_ns),
|
||||
total: Duration::from_nanos(total_ns),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
//! Text Generation gRPC client library
|
||||
|
||||
use async_trait::async_trait;
|
||||
use thiserror::Error;
|
||||
use tonic::transport;
|
||||
use tonic::Status;
|
||||
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
mod pb;
|
||||
|
||||
mod grpc_client;
|
||||
mod sharded_client;
|
||||
|
||||
pub use grpc_client::Client;
|
||||
pub use pb::generate::v2::{
|
||||
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, HealthResponse,
|
||||
InfoResponse, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Health {
|
||||
/// Check if a generate server is healthy by asking it to allocate a tensor on device
|
||||
async fn device_health(&self) -> Result<()>;
|
||||
|
||||
/// Check if a generate server is healthy by doing a forward pass.
|
||||
/// EXPENSIVE
|
||||
async fn model_health(&self) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ShardInfo {
|
||||
pub requires_padding: bool,
|
||||
pub dtype: String,
|
||||
pub device_type: String,
|
||||
pub window_size: Option<u32>,
|
||||
pub speculate: u32,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug, Clone)]
|
||||
pub enum ClientError {
|
||||
#[error("Could not connect to Text Generation server: {0}")]
|
||||
Connection(String),
|
||||
#[error("Server error: {0}")]
|
||||
Generation(String),
|
||||
#[error("Sharded results are empty")]
|
||||
EmptyResults,
|
||||
}
|
||||
|
||||
impl From<Status> for ClientError {
|
||||
fn from(err: Status) -> Self {
|
||||
let err = Self::Generation(err.message().to_string());
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
}
|
||||
}
|
||||
|
||||
impl From<transport::Error> for ClientError {
|
||||
fn from(err: transport::Error) -> Self {
|
||||
let err = Self::Connection(err.to_string());
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
}
|
||||
}
|
||||
|
||||
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ClientError>;
|
|
@ -0,0 +1,252 @@
|
|||
/// Multi shard Client
|
||||
use crate::client::{ClientError, Result};
|
||||
use crate::client::{Health, ShardInfo};
|
||||
|
||||
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
|
||||
use crate::client::InfoResponse;
|
||||
use crate::client::{
|
||||
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures::future::join_all;
|
||||
use tonic::transport::Uri;
|
||||
use tracing::instrument;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Text Generation Inference gRPC multi client
|
||||
pub struct ShardedClient {
|
||||
clients: Vec<Client>,
|
||||
}
|
||||
|
||||
impl ShardedClient {
|
||||
fn new(clients: Vec<Client>) -> Self {
|
||||
Self { clients }
|
||||
}
|
||||
|
||||
/// Create a new ShardedClient from a master client. The master client will communicate with
|
||||
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
|
||||
async fn from_master_client(mut master_client: Client) -> Result<Self> {
|
||||
// Get all uris/unix sockets from the master client
|
||||
let uris = master_client.service_discovery().await?;
|
||||
let futures = uris.into_iter().map(Client::connect_uds);
|
||||
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
|
||||
Ok(Self::new(clients?))
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given uri
|
||||
#[allow(dead_code)]
|
||||
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||
let master_client = Client::connect(uri).await?;
|
||||
Self::from_master_client(master_client).await
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given unix socket
|
||||
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||
let master_client = Client::connect_uds(path).await?;
|
||||
Self::from_master_client(master_client).await
|
||||
}
|
||||
|
||||
/// Get the model info
|
||||
#[instrument(skip(self))]
|
||||
pub async fn info(&mut self) -> Result<ShardInfo> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.info())
|
||||
.collect();
|
||||
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
|
||||
}
|
||||
|
||||
/// GRPC health check
|
||||
#[instrument(skip(self))]
|
||||
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.health())
|
||||
.collect();
|
||||
join_all(futures).await.pop().unwrap()
|
||||
}
|
||||
|
||||
/// Clear the past generations cache
|
||||
#[instrument(skip(self))]
|
||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.clear_cache(batch_id))
|
||||
.collect();
|
||||
join_all(futures).await.into_iter().collect()
|
||||
}
|
||||
|
||||
/// Filter a cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn filter_batch(
|
||||
&mut self,
|
||||
batch_id: u64,
|
||||
request_ids: Vec<u64>,
|
||||
) -> Result<Option<CachedBatch>> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
|
||||
.collect();
|
||||
// all shards return the same message
|
||||
join_all(futures).await.pop().unwrap()
|
||||
}
|
||||
|
||||
/// Warmup on a max size batch
|
||||
///
|
||||
/// Returns the maximum amount of tokens supported by the hardware
|
||||
#[instrument(skip(self))]
|
||||
pub async fn warmup(
|
||||
&mut self,
|
||||
max_input_length: u32,
|
||||
max_prefill_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<Option<u32>> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| {
|
||||
Box::pin(client.warmup(
|
||||
max_input_length,
|
||||
max_prefill_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
// Take the minimum value
|
||||
let results = join_all(futures)
|
||||
.await
|
||||
.into_iter()
|
||||
.collect::<Result<Vec<Option<u32>>>>()?;
|
||||
Ok(results.into_iter().flatten().min())
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given batch
|
||||
///
|
||||
/// Returns Generation for each request in batch
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
|
||||
pub async fn prefill(
|
||||
&mut self,
|
||||
batch: Batch,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.prefill(batch.clone())))
|
||||
.collect();
|
||||
#[allow(clippy::type_complexity)]
|
||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
||||
join_all(futures).await.into_iter().collect();
|
||||
let mut results = results?;
|
||||
|
||||
let (mut generations, next_batch, mut timings) =
|
||||
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||
|
||||
// Merge generations from different model shards
|
||||
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||
generations.append(&mut shard_generations);
|
||||
// Return the timings of the slowest shard
|
||||
if shard_timings.total > timings.total {
|
||||
timings = shard_timings;
|
||||
}
|
||||
}
|
||||
Ok((generations, next_batch, timings))
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given cached batches
|
||||
///
|
||||
/// Returns Generation for each request in batches
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
|
||||
pub async fn decode(
|
||||
&mut self,
|
||||
batches: Vec<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.decode(batches.clone())))
|
||||
.collect();
|
||||
#[allow(clippy::type_complexity)]
|
||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
|
||||
join_all(futures).await.into_iter().collect();
|
||||
let mut results = results?;
|
||||
|
||||
let (mut generations, next_batch, mut timings) =
|
||||
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||
|
||||
// Merge generations from different model shards
|
||||
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||
generations.append(&mut shard_generations);
|
||||
// Return the timings of the slowest shard
|
||||
if shard_timings.total > timings.total {
|
||||
timings = shard_timings;
|
||||
}
|
||||
}
|
||||
Ok((generations, next_batch, timings))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<InfoResponse> for ShardInfo {
|
||||
fn from(value: InfoResponse) -> Self {
|
||||
Self {
|
||||
requires_padding: value.requires_padding,
|
||||
dtype: value.dtype,
|
||||
device_type: value.device_type,
|
||||
window_size: value.window_size,
|
||||
speculate: value.speculate,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Health for ShardedClient {
|
||||
async fn device_health(&self) -> Result<()> {
|
||||
self.clone().health().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn model_health(&self) -> Result<()> {
|
||||
// Dummy batch of 1 token and 1 generated token
|
||||
let liveness_request = Request {
|
||||
id: u64::MAX,
|
||||
inputs: "liveness".to_string(),
|
||||
truncate: 10,
|
||||
prefill_logprobs: false,
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 1.0,
|
||||
top_k: 0,
|
||||
top_p: 1.0,
|
||||
typical_p: 1.0,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 1.0,
|
||||
frequency_penalty: 0.0,
|
||||
watermark: false,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: 1,
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: false,
|
||||
}),
|
||||
top_n_tokens: 0,
|
||||
};
|
||||
let batch = Batch {
|
||||
id: u64::MAX,
|
||||
requests: vec![liveness_request],
|
||||
size: 1,
|
||||
max_tokens: 2,
|
||||
};
|
||||
self.clone().prefill(batch).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,141 @@
|
|||
mod backend;
|
||||
mod client;
|
||||
mod queue;
|
||||
|
||||
use crate::client::{ClientError, ShardedClient};
|
||||
pub(crate) use backend::BackendV2;
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
pub struct BackendInfo {
|
||||
/// Mandatory
|
||||
#[schema(example = "cuda")]
|
||||
pub model_device_type: String,
|
||||
#[schema(example = "torch.float16")]
|
||||
pub model_dtype: String,
|
||||
|
||||
/// Backend parameters
|
||||
#[schema(example = "1")]
|
||||
pub speculate: usize,
|
||||
#[schema(example = "1.2")]
|
||||
pub waiting_served_ratio: f32,
|
||||
#[schema(example = "32000")]
|
||||
pub max_batch_total_tokens: u32,
|
||||
#[schema(example = "20")]
|
||||
pub max_waiting_tokens: usize,
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub max_batch_size: Option<usize>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn connect_backend(
|
||||
max_input_tokens: usize,
|
||||
max_total_tokens: usize,
|
||||
master_shard_uds_path: String,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<(BackendV2, BackendInfo), V2Error> {
|
||||
// Helper function
|
||||
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
||||
match max_supported_batch_total_tokens {
|
||||
// Older models do not support automatic max-batch-total-tokens
|
||||
None => {
|
||||
let max_batch_total_tokens = max_batch_total_tokens
|
||||
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
|
||||
tracing::warn!("Model does not support automatic max batch total tokens");
|
||||
Ok(max_batch_total_tokens)
|
||||
}
|
||||
// Flash attention models return their max supported total tokens
|
||||
Some(max_supported_batch_total_tokens) => {
|
||||
// Warn if user added his own max-batch-total-tokens as we will ignore it
|
||||
if max_batch_total_tokens.is_some() {
|
||||
tracing::warn!(
|
||||
"`--max-batch-total-tokens` is deprecated for Flash \
|
||||
Attention models."
|
||||
);
|
||||
tracing::warn!(
|
||||
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||
);
|
||||
}
|
||||
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
||||
return Err(V2Error::NotEnoughMemory(max_total_tokens));
|
||||
}
|
||||
|
||||
Ok(max_supported_batch_total_tokens)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||
.await
|
||||
.map_err(V2Error::Connection)?;
|
||||
|
||||
// server is running on v2
|
||||
// Clear the cache; useful if the webserver rebooted
|
||||
sharded_client
|
||||
.clear_cache(None)
|
||||
.await
|
||||
.map_err(V2Error::Cache)?;
|
||||
// Get info from the shard
|
||||
let shard_info = sharded_client.info().await.map_err(V2Error::Info)?;
|
||||
|
||||
// Warmup model
|
||||
tracing::info!("Warming up model");
|
||||
let max_batch_total_tokens = check_max_batch_total_tokens(
|
||||
sharded_client
|
||||
.warmup(
|
||||
max_input_tokens as u32,
|
||||
max_batch_prefill_tokens,
|
||||
max_total_tokens as u32,
|
||||
max_batch_size,
|
||||
)
|
||||
.await
|
||||
.map_err(V2Error::Warmup)?,
|
||||
)?;
|
||||
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||
|
||||
let backend_info = BackendInfo {
|
||||
waiting_served_ratio,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
model_device_type: shard_info.device_type.clone(),
|
||||
model_dtype: shard_info.dtype.clone(),
|
||||
speculate: shard_info.speculate as usize,
|
||||
};
|
||||
|
||||
let backend = BackendV2::new(
|
||||
sharded_client,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
shard_info.requires_padding,
|
||||
shard_info.window_size,
|
||||
shard_info.speculate,
|
||||
);
|
||||
|
||||
tracing::info!("Using backend V3");
|
||||
|
||||
Ok((backend, backend_info))
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum V2Error {
|
||||
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||
Cache(ClientError),
|
||||
#[error("Unable to connect to the Python model shards: {0}")]
|
||||
Connection(ClientError),
|
||||
#[error("Unable to get the Python model shards info: {0}")]
|
||||
Info(ClientError),
|
||||
#[error("Unable to warmup the Python model shards: {0}")]
|
||||
Warmup(ClientError),
|
||||
#[error("Not enough memory to handle `max_total_tokens={0}`")]
|
||||
NotEnoughMemory(usize),
|
||||
}
|
|
@ -0,0 +1,212 @@
|
|||
use clap::{Parser, Subcommand};
|
||||
use text_generation_router::{server, usage_stats};
|
||||
use text_generation_router_v2::{connect_backend, V2Error};
|
||||
use thiserror::Error;
|
||||
|
||||
/// App Configuration
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
#[command(subcommand)]
|
||||
command: Option<Commands>,
|
||||
|
||||
#[clap(default_value = "128", long, env)]
|
||||
max_concurrent_requests: usize,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
max_best_of: usize,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_stop_sequences: usize,
|
||||
#[clap(default_value = "5", long, env)]
|
||||
max_top_n_tokens: u32,
|
||||
#[clap(default_value = "1024", long, env)]
|
||||
max_input_tokens: usize,
|
||||
#[clap(default_value = "2048", long, env)]
|
||||
max_total_tokens: usize,
|
||||
#[clap(default_value = "1.2", long, env)]
|
||||
waiting_served_ratio: f32,
|
||||
#[clap(default_value = "4096", long, env)]
|
||||
max_batch_prefill_tokens: u32,
|
||||
#[clap(long, env)]
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
#[clap(default_value = "20", long, env)]
|
||||
max_waiting_tokens: usize,
|
||||
#[clap(long, env)]
|
||||
max_batch_size: Option<usize>,
|
||||
#[clap(default_value = "0.0.0.0", long, env)]
|
||||
hostname: String,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
port: u16,
|
||||
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
|
||||
master_shard_uds_path: String,
|
||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||
tokenizer_name: String,
|
||||
#[clap(long, env)]
|
||||
tokenizer_config_path: Option<String>,
|
||||
#[clap(long, env)]
|
||||
revision: Option<String>,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
validation_workers: usize,
|
||||
#[clap(long, env)]
|
||||
api_key: Option<String>,
|
||||
#[clap(long, env)]
|
||||
json_output: bool,
|
||||
#[clap(long, env)]
|
||||
otlp_endpoint: Option<String>,
|
||||
#[clap(default_value = "text-generation-inference.router", long, env)]
|
||||
otlp_service_name: String,
|
||||
#[clap(long, env)]
|
||||
cors_allow_origin: Option<Vec<String>>,
|
||||
#[clap(long, env)]
|
||||
ngrok: bool,
|
||||
#[clap(long, env)]
|
||||
ngrok_authtoken: Option<String>,
|
||||
#[clap(long, env)]
|
||||
ngrok_edge: Option<String>,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
messages_api_enabled: bool,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
disable_grammar_support: bool,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_client_batch_size: usize,
|
||||
#[clap(default_value = "on", long, env)]
|
||||
usage_stats: usage_stats::UsageStatsLevel,
|
||||
}
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
enum Commands {
|
||||
PrintSchema,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), RouterError> {
|
||||
// Get args
|
||||
let args = Args::parse();
|
||||
// Pattern match configuration
|
||||
let Args {
|
||||
command,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
hostname,
|
||||
port,
|
||||
master_shard_uds_path,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
validation_workers,
|
||||
api_key,
|
||||
json_output,
|
||||
otlp_endpoint,
|
||||
otlp_service_name,
|
||||
cors_allow_origin,
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
usage_stats,
|
||||
} = args;
|
||||
|
||||
if let Some(Commands::PrintSchema) = command {
|
||||
use utoipa::OpenApi;
|
||||
let api_doc = text_generation_router::server::ApiDoc::openapi();
|
||||
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
|
||||
println!("{}", api_doc);
|
||||
std::process::exit(0);
|
||||
};
|
||||
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||
|
||||
// Validate args
|
||||
if max_input_tokens >= max_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||
}
|
||||
|
||||
if validation_workers == 0 {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`validation_workers` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(max_batch_size) = max_batch_size {
|
||||
if max_batch_size == 0 {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_batch_size` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let (backend, _backend_info) = connect_backend(
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
master_shard_uds_path,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Run server
|
||||
server::run(
|
||||
backend,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
validation_workers,
|
||||
api_key,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
hostname,
|
||||
port,
|
||||
cors_allow_origin,
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
usage_stats,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum RouterError {
|
||||
#[error("Argument validation error: {0}")]
|
||||
ArgumentValidation(String),
|
||||
#[error("Backend failed: {0}")]
|
||||
Backend(#[from] V2Error),
|
||||
#[error("WebServer error: {0}")]
|
||||
WebServer(#[from] server::WebServerError),
|
||||
#[error("Tokio runtime failed to start: {0}")]
|
||||
Tokio(#[from] std::io::Error),
|
||||
}
|
|
@ -1,14 +1,14 @@
|
|||
use crate::infer::{InferError, InferStreamResponse};
|
||||
use crate::validation::{
|
||||
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
||||
use crate::client::{
|
||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::min;
|
||||
use std::collections::VecDeque;
|
||||
use text_generation_client::v2::{
|
||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
use text_generation_router::infer::InferError;
|
||||
use text_generation_router::infer::InferStreamResponse;
|
||||
use text_generation_router::validation::{
|
||||
ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
||||
};
|
||||
use text_generation_client::ChunksToString;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{info_span, instrument, Span};
|
||||
|
@ -218,7 +218,7 @@ impl State {
|
|||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
||||
next_batch_span.follows_from(&Span::current());
|
||||
next_batch_span.follows_from(Span::current());
|
||||
|
||||
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
||||
let mut batch_entries =
|
||||
|
@ -404,6 +404,7 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
use tracing::info_span;
|
||||
|
||||
fn default_entry() -> (
|
||||
|
@ -415,7 +416,9 @@ mod tests {
|
|||
let entry = Entry {
|
||||
request: ValidGenerateRequest {
|
||||
inputs: vec![],
|
||||
input_ids: Some(Arc::new(vec![])),
|
||||
input_length: 0,
|
||||
add_special_tokens: true,
|
||||
truncate: 0,
|
||||
decoder_input_details: false,
|
||||
parameters: ValidParameters {
|
|
@ -8,9 +8,11 @@ use crate::{
|
|||
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
||||
Message, PrefillToken, Token,
|
||||
};
|
||||
use async_stream::stream;
|
||||
use async_trait::async_trait;
|
||||
use chat_template::ChatTemplate;
|
||||
use futures::future::try_join_all;
|
||||
use futures::Stream;
|
||||
use minijinja::ErrorKind;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
@ -87,7 +89,14 @@ impl Infer {
|
|||
pub(crate) async fn generate_stream<'a>(
|
||||
&'a self,
|
||||
request: GenerateRequest,
|
||||
) -> Result<GenerateStreamResponse, InferError> {
|
||||
) -> Result<
|
||||
(
|
||||
OwnedSemaphorePermit,
|
||||
u32, // input_length
|
||||
impl Stream<Item = Result<InferStreamResponse, InferError>> + 'a,
|
||||
),
|
||||
InferError,
|
||||
> {
|
||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||
let permit = self
|
||||
.clone()
|
||||
|
@ -107,9 +116,18 @@ impl Infer {
|
|||
})?;
|
||||
|
||||
let input_length = valid_request.input_length;
|
||||
let generation_stream = self.backend.schedule(valid_request)?;
|
||||
let mut generation_stream = self.backend.schedule(valid_request)?;
|
||||
|
||||
Ok((permit, input_length, generation_stream))
|
||||
// Wrap generation stream to update the backend health if the stream contains an error
|
||||
let final_stream = stream! {
|
||||
while let Some(response) = generation_stream.next().await {
|
||||
yield response.inspect_err(|_err| {
|
||||
self.backend_health.store(false, Ordering::SeqCst);
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
Ok((permit, input_length, final_stream))
|
||||
}
|
||||
|
||||
/// Tokenizer the input
|
||||
|
@ -278,13 +296,6 @@ impl Infer {
|
|||
}
|
||||
}
|
||||
|
||||
/// Type alias for generation responses
|
||||
pub(crate) type GenerateStreamResponse = (
|
||||
OwnedSemaphorePermit,
|
||||
u32, // input_length
|
||||
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
|
||||
);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct GeneratedText {
|
||||
pub text: String,
|
||||
|
|
|
@ -1,4 +0,0 @@
|
|||
mod queue;
|
||||
mod scheduler;
|
||||
|
||||
pub(crate) use scheduler::BackendV2;
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue