fix(router): add timeout on flume sends (#488)
This commit is contained in:
parent
776d150c55
commit
bd3a9d8e85
|
@ -3,7 +3,7 @@ use crate::validation::{Validation, ValidationError};
|
|||
use crate::{Entry, Queue, Token};
|
||||
use crate::{GenerateRequest, PrefillToken};
|
||||
use flume::r#async::RecvStream;
|
||||
use flume::SendError;
|
||||
use flume::SendTimeoutError;
|
||||
use futures::future::try_join_all;
|
||||
use futures::stream::StreamExt;
|
||||
use nohash_hasher::IntMap;
|
||||
|
@ -11,6 +11,7 @@ use std::sync::{
|
|||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use std::time::Duration;
|
||||
use text_generation_client::{
|
||||
Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
|
||||
};
|
||||
|
@ -472,6 +473,10 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
|
|||
// If the receive an error from the Flume channel, it means that the client dropped the
|
||||
// request and we need to stop generating hence why we unwrap_or(true)
|
||||
let stopped = send_responses(generation, entry).map_err(|err| {
|
||||
if let SendTimeoutError::Timeout(_) = *err {
|
||||
tracing::error!("Entry response channel timed out.")
|
||||
}
|
||||
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
err
|
||||
}).unwrap_or(true);
|
||||
|
@ -485,14 +490,20 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
|
|||
fn send_responses(
|
||||
generation: Generation,
|
||||
entry: &Entry,
|
||||
) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
|
||||
) -> Result<bool, Box<SendTimeoutError<Result<InferStreamResponse, InferError>>>> {
|
||||
// Return directly if the channel is disconnected
|
||||
if entry.response_tx.is_disconnected() {
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
let mut stopped = false;
|
||||
|
||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
|
||||
entry.response_tx.send_timeout(
|
||||
Ok(InferStreamResponse::Prefill(prefill_tokens)),
|
||||
Duration::from_millis(10),
|
||||
)?;
|
||||
}
|
||||
|
||||
// Create last Token
|
||||
|
@ -507,17 +518,21 @@ fn send_responses(
|
|||
// Generation has ended
|
||||
stopped = true;
|
||||
// Send message
|
||||
entry.response_tx.send(Ok(InferStreamResponse::End {
|
||||
token,
|
||||
generated_text,
|
||||
queued: entry.queue_time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
}))?;
|
||||
entry.response_tx.send_timeout(
|
||||
Ok(InferStreamResponse::End {
|
||||
token,
|
||||
generated_text,
|
||||
queued: entry.queue_time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
}),
|
||||
Duration::from_millis(10),
|
||||
)?;
|
||||
} else {
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Token(token)))?;
|
||||
entry.response_tx.send_timeout(
|
||||
Ok(InferStreamResponse::Token(token)),
|
||||
Duration::from_millis(10),
|
||||
)?;
|
||||
}
|
||||
Ok(stopped)
|
||||
}
|
||||
|
@ -535,7 +550,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
|||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
entry
|
||||
.response_tx
|
||||
.send(Err(err))
|
||||
.send_timeout(Err(err), Duration::from_millis(10))
|
||||
.unwrap_or(());
|
||||
});
|
||||
}
|
||||
|
|
|
@ -95,7 +95,7 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueComma
|
|||
span,
|
||||
} => span.in_scope(|| {
|
||||
let next_batch = state.next_batch(min_size, token_budget);
|
||||
response_sender.send(next_batch).unwrap_or(());
|
||||
response_sender.send(next_batch).unwrap();
|
||||
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
|
||||
}),
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue