fix(router): add timeout on flume sends (#488)

This commit is contained in:
OlivierDehaene 2023-06-23 14:58:28 +02:00 committed by GitHub
parent 776d150c55
commit bd3a9d8e85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 16 deletions

View File

@@ -3,7 +3,7 @@ use crate::validation::{Validation, ValidationError};
 use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
 use flume::r#async::RecvStream;
-use flume::SendError;
+use flume::SendTimeoutError;
 use futures::future::try_join_all;
 use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
@@ -11,6 +11,7 @@ use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
+use std::time::Duration;
 use text_generation_client::{
     Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
@@ -472,6 +473,10 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
         // If the receive an error from the Flume channel, it means that the client dropped the
         // request and we need to stop generating hence why we unwrap_or(true)
         let stopped = send_responses(generation, entry).map_err(|err| {
+            if let SendTimeoutError::Timeout(_) = *err {
+                tracing::error!("Entry response channel timed out.")
+            }
+
             metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
             err
         }).unwrap_or(true);
@@ -485,14 +490,20 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
 fn send_responses(
     generation: Generation,
     entry: &Entry,
-) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
+) -> Result<bool, Box<SendTimeoutError<Result<InferStreamResponse, InferError>>>> {
+    // Return directly if the channel is disconnected
+    if entry.response_tx.is_disconnected() {
+        return Ok(true);
+    }
+
     let mut stopped = false;
 
     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
-        entry
-            .response_tx
-            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
+        entry.response_tx.send_timeout(
+            Ok(InferStreamResponse::Prefill(prefill_tokens)),
+            Duration::from_millis(10),
+        )?;
     }
 
     // Create last Token
@@ -507,17 +518,21 @@ fn send_responses(
         // Generation has ended
         stopped = true;
         // Send message
-        entry.response_tx.send(Ok(InferStreamResponse::End {
-            token,
-            generated_text,
-            queued: entry.queue_time,
-            start: entry.batch_time.unwrap(),
-        }))?;
+        entry.response_tx.send_timeout(
+            Ok(InferStreamResponse::End {
+                token,
+                generated_text,
+                queued: entry.queue_time,
+                start: entry.batch_time.unwrap(),
+            }),
+            Duration::from_millis(10),
+        )?;
     } else {
         // Send message
-        entry
-            .response_tx
-            .send(Ok(InferStreamResponse::Token(token)))?;
+        entry.response_tx.send_timeout(
+            Ok(InferStreamResponse::Token(token)),
+            Duration::from_millis(10),
+        )?;
     }
     Ok(stopped)
 }
@@ -535,7 +550,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
         // unwrap_or is valid here as we don't care if the receiver is gone.
         entry
             .response_tx
-            .send(Err(err))
+            .send_timeout(Err(err), Duration::from_millis(10))
             .unwrap_or(());
     });
 }

View File

@@ -95,7 +95,7 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueComma
             span,
         } => span.in_scope(|| {
             let next_batch = state.next_batch(min_size, token_budget);
-            response_sender.send(next_batch).unwrap_or(());
+            response_sender.send(next_batch).unwrap();
             metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
         }),
     }