fix(router): add timeout on flume sends (#488)
parent 776d150c55
commit bd3a9d8e85
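This change replaces every blocking `send` on an entry's response channel with a `send_timeout` capped at 10 ms, so a send that cannot complete no longer stalls the batching loop: a timeout is logged and then handled like a dropped request, `send_responses` returns early when the receiver is already disconnected, and `send_errors` keeps ignoring failures. On the queue side, the reply carrying the next batch is now `unwrap`ped instead of silently discarded.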
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -3,7 +3,7 @@ use crate::validation::{Validation, ValidationError};
 use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
 use flume::r#async::RecvStream;
-use flume::SendError;
+use flume::SendTimeoutError;
 use futures::future::try_join_all;
 use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
@@ -11,6 +11,7 @@ use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
+use std::time::Duration;
 use text_generation_client::{
     Batch, CachedBatch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
@@ -472,6 +473,10 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
         // If we receive an error from the Flume channel, it means that the client dropped the
         // request and we need to stop generating hence why we unwrap_or(true)
         let stopped = send_responses(generation, entry).map_err(|err| {
+            if let SendTimeoutError::Timeout(_) = *err {
+                tracing::error!("Entry response channel timed out.")
+            }
+
             metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
             err
         }).unwrap_or(true);
@@ -485,14 +490,20 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
 fn send_responses(
     generation: Generation,
     entry: &Entry,
-) -> Result<bool, SendError<Result<InferStreamResponse, InferError>>> {
+) -> Result<bool, Box<SendTimeoutError<Result<InferStreamResponse, InferError>>>> {
+    // Return directly if the channel is disconnected
+    if entry.response_tx.is_disconnected() {
+        return Ok(true);
+    }
+
     let mut stopped = false;
 
     if let Some(prefill_tokens) = generation.prefill_tokens {
         // Send message
-        entry
-            .response_tx
-            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
+        entry.response_tx.send_timeout(
+            Ok(InferStreamResponse::Prefill(prefill_tokens)),
+            Duration::from_millis(10),
+        )?;
     }
 
     // Create last Token
@@ -507,17 +518,21 @@ fn send_responses(
         // Generation has ended
         stopped = true;
         // Send message
-        entry.response_tx.send(Ok(InferStreamResponse::End {
-            token,
-            generated_text,
-            queued: entry.queue_time,
-            start: entry.batch_time.unwrap(),
-        }))?;
+        entry.response_tx.send_timeout(
+            Ok(InferStreamResponse::End {
+                token,
+                generated_text,
+                queued: entry.queue_time,
+                start: entry.batch_time.unwrap(),
+            }),
+            Duration::from_millis(10),
+        )?;
     } else {
         // Send message
-        entry
-            .response_tx
-            .send(Ok(InferStreamResponse::Token(token)))?;
+        entry.response_tx.send_timeout(
+            Ok(InferStreamResponse::Token(token)),
+            Duration::from_millis(10),
+        )?;
     }
     Ok(stopped)
 }
@@ -535,7 +550,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
         // unwrap_or is valid here as we don't care if the receiver is gone.
         entry
             .response_tx
-            .send(Err(err))
+            .send_timeout(Err(err), Duration::from_millis(10))
             .unwrap_or(());
     });
 }
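For context: flume's `Sender::send_timeout` returns `SendTimeoutError::Timeout` when the deadline expires before the message can be accepted, and `SendTimeoutError::Disconnected` once every receiver is gone; the hunks above log only the timeout case and treat both as a dropped request. The error is also boxed now, presumably because `SendTimeoutError<T>` carries the unsent message and would otherwise enlarge the `Result`. A minimal, self-contained sketch of that API (a toy zero-capacity channel with a made-up payload, not TGI code, assuming `flume = "0.10"`):

    use std::time::Duration;

    use flume::SendTimeoutError;

    fn main() {
        // A rendezvous channel: a send succeeds only while a receiver is
        // actively waiting, which makes the timeout path easy to trigger.
        let (tx, rx) = flume::bounded::<u32>(0);

        // No receiver is waiting, so this gives up after 10 ms instead of
        // blocking the caller forever.
        match tx.send_timeout(1, Duration::from_millis(10)) {
            Err(SendTimeoutError::Timeout(msg)) => eprintln!("timed out sending {msg}"),
            Err(SendTimeoutError::Disconnected(msg)) => eprintln!("receiver gone, lost {msg}"),
            Ok(()) => unreachable!("no receiver is waiting"),
        }

        // Once the receiver is dropped, `is_disconnected` lets a sender bail
        // out up front, as the new fast path in `send_responses` does.
        drop(rx);
        assert!(tx.is_disconnected());
    }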
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -95,7 +95,7 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueCommand>) {
                 span,
             } => span.in_scope(|| {
                 let next_batch = state.next_batch(min_size, token_budget);
-                response_sender.send(next_batch).unwrap_or(());
+                response_sender.send(next_batch).unwrap();
                 metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
             }),
         }
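Note the inverted handling in `queue_task`: the reply channel's `send` result is now `unwrap`ped rather than ignored, presumably because the batcher that issued the command is always alive and waiting for `next_batch`, so a failed reply there signals a bug worth panicking on rather than a dropped client.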