fix(router): add timeout on flume sends (#488)

This commit is contained in:
OlivierDehaene 2023-06-23 14:58:28 +02:00 committed by GitHub
parent 776d150c55
commit bd3a9d8e85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 16 deletions

View File

@ -3,7 +3,7 @@ use crate::validation::{Validation, ValidationError};
use crate::{Entry, Queue, Token};
use crate::{GenerateRequest, PrefillToken};
use flume::r#async::RecvStream;
use flume::SendError;
use flume::SendTimeoutError;
use futures::future::try_join_all;
use futures::stream::StreamExt;
use nohash_hasher::IntMap;
@ -11,6 +11,7 @@ use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use std::time::Duration;
use text_generation_client::{
Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
};
@ -472,6 +473,10 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
// If the receive an error from the Flume channel, it means that the client dropped the
// request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| {
if let SendTimeoutError::Timeout(_) = *err {
tracing::error!("Entry response channel timed out.")
}
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
err
}).unwrap_or(true);
@ -485,14 +490,20 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
fn send_responses(
generation: Generation,
entry: &Entry,
) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
) -> Result<bool, Box<SendTimeoutError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_disconnected() {
return Ok(true);
}
let mut stopped = false;
if let Some(prefill_tokens) = generation.prefill_tokens {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
entry.response_tx.send_timeout(
Ok(InferStreamResponse::Prefill(prefill_tokens)),
Duration::from_millis(10),
)?;
}
// Create last Token
@ -507,17 +518,21 @@ fn send_responses(
// Generation has ended
stopped = true;
// Send message
entry.response_tx.send(Ok(InferStreamResponse::End {
entry.response_tx.send_timeout(
Ok(InferStreamResponse::End {
token,
generated_text,
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
}),
Duration::from_millis(10),
)?;
} else {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Token(token)))?;
entry.response_tx.send_timeout(
Ok(InferStreamResponse::Token(token)),
Duration::from_millis(10),
)?;
}
Ok(stopped)
}
@ -535,7 +550,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Err(err))
.send_timeout(Err(err), Duration::from_millis(10))
.unwrap_or(());
});
}

View File

@ -95,7 +95,7 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueComma
span,
} => span.in_scope(|| {
let next_batch = state.next_batch(min_size, token_budget);
response_sender.send(next_batch).unwrap_or(());
response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
}),
}