feat: remove flume (#1184)
parent 12590fdcce
commit f9910d13e2
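In short, every flume channel in the router is replaced with a tokio::sync::mpsc unbounded channel, and flume, nanorand and the duplicate spin entry drop out of the lock file. Below is a minimal sketch of the API mapping this diff applies; it is not taken from the repository itself and assumes tokio 1.x with the "rt", "macros" and "sync" features plus tokio-stream 0.1:

use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;

#[tokio::main]
async fn main() {
    // flume::unbounded()            -> mpsc::unbounded_channel()
    let (tx, mut rx) = mpsc::unbounded_channel::<u32>();

    // tx.send_timeout(msg, dur)     -> tx.send(msg); an unbounded send never waits,
    // so the timeout variant disappears and only SendError (receiver gone) remains.
    tx.send(1).unwrap();

    // rx.recv_async().await == Ok(..)  -> rx.recv().await == Some(..)
    // (blocking consumers use rx.blocking_recv() instead of flume's rx.recv())
    assert_eq!(rx.recv().await, Some(1));

    // rx.into_stream()              -> UnboundedReceiverStream::new(rx)
    let stream = UnboundedReceiverStream::new(rx);
    drop(stream);

    // tx.is_disconnected()          -> tx.is_closed()
    assert!(tx.is_closed());
}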

Cargo.lock
@@ -743,18 +743,6 @@ version = "1.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "28a80e3145d8ad11ba0995949bbcf48b9df2be62772b3d351ef017dff6ecb853"

-[[package]]
-name = "flume"
-version = "0.11.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181"
-dependencies = [
- "futures-core",
- "futures-sink",
- "nanorand",
- "spin 0.9.8",
-]
-
 [[package]]
 name = "fnv"
 version = "1.0.7"
@@ -900,10 +888,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427"
 dependencies = [
  "cfg-if",
- "js-sys",
  "libc",
  "wasi",
- "wasm-bindgen",
 ]

 [[package]]
@@ -1508,15 +1494,6 @@ dependencies = [
 "tracing",
 ]

-[[package]]
-name = "nanorand"
-version = "0.7.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3"
-dependencies = [
- "getrandom",
-]
-
 [[package]]
 name = "native-tls"
 version = "0.2.11"
@@ -2313,7 +2290,7 @@ dependencies = [
 "cc",
 "libc",
 "once_cell",
- "spin 0.5.2",
+ "spin",
 "untrusted",
 "web-sys",
 "winapi",
@@ -2678,15 +2655,6 @@ version = "0.5.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"

-[[package]]
-name = "spin"
-version = "0.9.8"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
-dependencies = [
- "lock_api",
-]
-
 [[package]]
 name = "spm_precompiled"
 version = "0.1.4"
@@ -2808,7 +2776,7 @@ dependencies = [

 [[package]]
 name = "text-generation-benchmark"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
 "average",
 "clap",
@@ -2829,7 +2797,7 @@ dependencies = [

 [[package]]
 name = "text-generation-client"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
 "futures",
 "grpc-metadata",
@@ -2845,7 +2813,7 @@ dependencies = [

 [[package]]
 name = "text-generation-launcher"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
 "clap",
 "ctrlc",
@@ -2861,13 +2829,12 @@ dependencies = [

 [[package]]
 name = "text-generation-router"
-version = "1.1.0"
+version = "1.1.1"
 dependencies = [
 "async-stream",
 "axum",
 "axum-tracing-opentelemetry",
 "clap",
- "flume",
 "futures",
 "hf-hub 0.3.1",
 "init-tracing-opentelemetry",
@@ -2885,6 +2852,7 @@ dependencies = [
 "thiserror",
 "tokenizers",
 "tokio",
+ "tokio-stream",
 "tower-http",
 "tracing",
 "tracing-opentelemetry",

Cargo.toml (router)
@@ -20,7 +20,6 @@ axum = { version = "0.6.20", features = ["json"] }
 axum-tracing-opentelemetry = "0.14.1"
 text-generation-client = { path = "client" }
 clap = { version = "4.4.5", features = ["derive", "env"] }
-flume = "0.11.0"
 futures = "0.3.28"
 metrics = "0.21.1"
 metrics-exporter-prometheus = { version = "0.12.1", features = [] }
@@ -34,6 +33,7 @@ serde_json = "1.0.107"
 thiserror = "1.0.48"
 tokenizers = { version = "0.14.0", features = ["http"] }
 tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio-stream = "0.1.14"
 tower-http = { version = "0.4.4", features = ["cors"] }
 tracing = "0.1.37"
 tracing-opentelemetry = "0.21.0"

client.rs
@@ -107,15 +107,14 @@ impl Client {
     ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
-        let mut truncate = 0;
         // Create requests
         while n_tokens < max_prefill_tokens {
-            truncate = min(max_input_length, max_prefill_tokens - n_tokens);
+            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
             requests.push(Request {
                 id: 0,
                 // We truncate the input on the server side to be sure that it has the correct size
                 inputs: "_test ".to_string().repeat(max_input_length as usize),
-                truncate: truncate,
+                truncate,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,

infer.rs
@@ -2,22 +2,21 @@
 use crate::validation::{Validation, ValidationError};
 use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
-use flume::r#async::RecvStream;
-use flume::SendTimeoutError;
 use futures::future::try_join_all;
-use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
 use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
-use std::time::Duration;
 use text_generation_client::{
     Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
-use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::mpsc::error::SendError;
+use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
 use tokio::time::Instant;
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tokio_stream::StreamExt;
 use tracing::{info_span, instrument, Instrument, Span};

 /// Inference struct
@@ -90,7 +89,7 @@ impl Infer {
     ) -> Result<
         (
             OwnedSemaphorePermit,
-            RecvStream<Result<InferStreamResponse, InferError>>,
+            UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
         ),
         InferError,
     > {
@@ -113,7 +112,7 @@ impl Infer {
         })?;

         // MPSC channel to communicate with the background batching task
-        let (response_tx, response_rx) = flume::unbounded();
+        let (response_tx, response_rx) = mpsc::unbounded_channel();

         // Append the request to the queue
         self.queue.append(Entry {
@@ -130,7 +129,7 @@ impl Infer {
         self.shared.batching_task.notify_one();

         // Return stream
-        Ok((permit, response_rx.into_stream()))
+        Ok((permit, UnboundedReceiverStream::new(response_rx)))
     }

     /// Add a new request to the queue and return a InferResponse
@@ -493,10 +492,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
         // If the receive an error from the Flume channel, it means that the client dropped the
         // request and we need to stop generating hence why we unwrap_or(true)
         let stopped = send_responses(generation, entry).map_err(|err| {
-            if let SendTimeoutError::Timeout(_) = *err {
-                tracing::error!("Entry response channel timed out.")
-            }
-
+            tracing::error!("Entry response channel error.");
             metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
             err
         }).unwrap_or(true);
@@ -510,9 +506,10 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
 fn send_responses(
     generation: Generation,
     entry: &Entry,
-) -> Result<bool, Box<SendTimeoutError<Result<InferStreamResponse, InferError>>>> {
+) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
     // Return directly if the channel is disconnected
-    if entry.response_tx.is_disconnected() {
+    if entry.response_tx.is_closed() {
         metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
         return Ok(true);
     }

@@ -520,10 +517,9 @@ fn send_responses(

     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
-        entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::Prefill(prefill_tokens)),
-            Duration::from_millis(10),
-        )?;
+        entry
+            .response_tx
+            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
     }

     // Create last Token
@@ -558,22 +554,18 @@ fn send_responses(
         // Generation has ended
         stopped = true;
         // Send message
-        entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::End {
-                token,
-                top_tokens,
-                generated_text,
-                queued: entry.queue_time,
-                start: entry.batch_time.unwrap(),
-            }),
-            Duration::from_millis(10),
-        )?;
+        entry.response_tx.send(Ok(InferStreamResponse::End {
+            token,
+            top_tokens,
+            generated_text,
+            queued: entry.queue_time,
+            start: entry.batch_time.unwrap(),
+        }))?;
     } else {
         // Send message
-        entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::Intermediate { token, top_tokens }),
-            Duration::from_millis(10),
-        )?;
+        entry
+            .response_tx
+            .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
     }
     Ok(stopped)
 }
@@ -591,7 +583,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
         // unwrap_or is valid here as we don't care if the receiver is gone.
         entry
             .response_tx
-            .send_timeout(Err(err), Duration::from_millis(10))
+            .send(Err(err))
             .unwrap_or(());
     });
 }
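For orientation, here is a self-contained sketch, with a toy payload type rather than the router's InferStreamResponse, of the producer/consumer shape infer.rs now uses: send_responses pushes into the unbounded sender with a plain send, and the stream handed back by generate_stream is the receiver wrapped in an UnboundedReceiverStream (assumes tokio 1.x and tokio-stream 0.1):

use tokio::sync::mpsc;
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};

#[tokio::main]
async fn main() {
    let (response_tx, response_rx) = mpsc::unbounded_channel::<Result<u32, String>>();

    // Producer side (what send_responses now does): no timeout; send only fails
    // with SendError if the receiving half was dropped, i.e. the client went away.
    response_tx.send(Ok(1)).unwrap();
    response_tx.send(Ok(2)).unwrap();
    drop(response_tx); // closing the sender ends the stream

    // Consumer side (roughly what the caller of generate_stream does).
    let mut stream = UnboundedReceiverStream::new(response_rx);
    while let Some(item) = stream.next().await {
        println!("{item:?}");
    }
}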

queue.rs
@@ -5,7 +5,7 @@ use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
 use std::collections::VecDeque;
 use text_generation_client::{Batch, Request};
-use tokio::sync::oneshot;
+use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Span};

@@ -15,7 +15,7 @@ pub(crate) struct Entry {
     /// Request
     pub request: ValidGenerateRequest,
     /// Response sender to communicate between the Infer struct and the batching_task
-    pub response_tx: flume::Sender<Result<InferStreamResponse, InferError>>,
+    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
     /// Span that will live as long as entry
     pub span: Span,
     /// Temporary span used as a guard when logging inference, wait times...
@@ -30,13 +30,13 @@ pub(crate) struct Entry {
 #[derive(Debug, Clone)]
 pub(crate) struct Queue {
     /// Channel to communicate with the background queue task
-    queue_sender: flume::Sender<QueueCommand>,
+    queue_sender: mpsc::UnboundedSender<QueueCommand>,
 }

 impl Queue {
     pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
         // Create channel
-        let (queue_sender, queue_receiver) = flume::unbounded();
+        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();

         // Launch background queue task
         tokio::spawn(queue_task(
@@ -91,11 +91,11 @@ async fn queue_task(
     requires_padding: bool,
     block_size: u32,
     window_size: Option<u32>,
-    receiver: flume::Receiver<QueueCommand>,
+    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
     let mut state = State::new(requires_padding, block_size, window_size);

-    while let Ok(cmd) = receiver.recv_async().await {
+    while let Some(cmd) = receiver.recv().await {
         match cmd {
             QueueCommand::Append(entry, span) => {
                 span.in_scope(|| state.append(*entry));
@@ -195,7 +195,7 @@ impl State {
         while let Some((id, mut entry)) = self.entries.pop_front() {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)
-            if entry.response_tx.is_disconnected() {
+            if entry.response_tx.is_closed() {
                 metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
                 continue;
             }
@@ -321,9 +321,9 @@ mod tests {

     fn default_entry() -> (
         Entry,
-        flume::Receiver<Result<InferStreamResponse, InferError>>,
+        mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
     ) {
-        let (response_tx, receiver_tx) = flume::unbounded();
+        let (response_tx, receiver_tx) = mpsc::unbounded_channel();

         let entry = Entry {
             request: ValidGenerateRequest {

validation.rs
@@ -6,6 +6,7 @@ use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParamet
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokenizers::TruncationDirection;
+use tokio::sync::mpsc;
 use tokio::sync::oneshot;
 use tracing::{instrument, Span};

@@ -19,7 +20,7 @@ pub struct Validation {
     max_input_length: usize,
     max_total_tokens: usize,
     /// Channel to communicate with the background tokenization task
-    sender: Option<flume::Sender<TokenizerRequest>>,
+    sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
 }

 impl Validation {
@@ -34,19 +35,25 @@ impl Validation {
     ) -> Self {
         // If we have a fast tokenizer
         let sender = if let Some(tokenizer) = tokenizer {
-            // Create channel
-            let (validation_sender, validation_receiver) = flume::unbounded();
+            // Create round robin channel
+            let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();
+            let mut senders = Vec::with_capacity(workers);

             // Create workers
             for _ in 0..workers {
                 let tokenizer_clone = tokenizer.clone();
-                let receiver_clone = validation_receiver.clone();
+                let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
+                senders.push(tokenizer_sender);
+
                 // Spawn worker
                 tokio::task::spawn_blocking(move || {
-                    tokenizer_worker(tokenizer_clone, receiver_clone)
+                    tokenizer_worker(tokenizer_clone, tokenizer_receiver)
                 });
             }

+            // Create tokenization round robin task
+            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));
+
             Some(validation_sender)
         } else {
             None
@@ -118,12 +125,10 @@ impl Validation {
         // We make sure that truncate + max_new_tokens <= self.max_total_tokens
         let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
             max_new_tokens
+        } else if let Some(truncate) = truncate {
+            self.max_total_tokens.saturating_sub(truncate) as u32
         } else {
-            if let Some(truncate) = truncate {
-                self.max_total_tokens.saturating_sub(truncate) as u32
-            } else {
-                return Err(ValidationError::UnsetMaxNewTokens);
-            }
+            return Err(ValidationError::UnsetMaxNewTokens);
         };
         let input_length = truncate.unwrap_or(self.max_input_length);

@@ -309,10 +314,25 @@ impl Validation {
     }
 }

+/// Round robin tokenization task
+async fn round_robin_task(
+    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
+    senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,
+) {
+    loop {
+        for sender in &senders {
+            match receiver.recv().await {
+                None => return,
+                Some(request) => sender.send(request).unwrap(),
+            };
+        }
+    }
+}
+
 /// Start tokenization workers
-fn tokenizer_worker(tokenizer: Tokenizer, receiver: flume::Receiver<TokenizerRequest>) {
+fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
     // Loop over requests
-    while let Ok(((inputs, truncate), response_tx, parent_span)) = receiver.recv() {
+    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
         parent_span.in_scope(|| {
             response_tx
                 .send(prepare_input(inputs, truncate, &tokenizer))
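A note on the validation change: flume receivers are cloneable, so every blocking tokenizer worker used to share one receiver, whereas tokio's mpsc receiver is single-consumer. That is why the new code gives each worker its own channel and adds round_robin_task to fan requests out. A toy sketch of that shape, with illustrative names and a u32 standing in for TokenizerRequest (assumes tokio 1.x with the multi-thread runtime):

use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let workers = 2;
    let (front_tx, mut front_rx) = mpsc::unbounded_channel::<u32>();

    // One private channel and one blocking worker per slot.
    let mut worker_txs = Vec::with_capacity(workers);
    let mut handles = Vec::with_capacity(workers);
    for id in 0..workers {
        let (tx, mut rx) = mpsc::unbounded_channel::<u32>();
        worker_txs.push(tx);
        handles.push(tokio::task::spawn_blocking(move || {
            while let Some(job) = rx.blocking_recv() {
                println!("worker {id} got job {job}");
            }
        }));
    }

    // Round-robin dispatcher: the single consumer of the front channel.
    tokio::spawn(async move {
        loop {
            for tx in &worker_txs {
                match front_rx.recv().await {
                    None => return, // front sender dropped: shut everything down
                    Some(job) => tx.send(job).unwrap(),
                }
            }
        }
    });

    for job in 0..4 {
        front_tx.send(job).unwrap();
    }
    drop(front_tx); // lets the dispatcher, then the workers, exit cleanly
    for handle in handles {
        handle.await.unwrap();
    }
}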