make sure the context is not dropped in the middle of the async decoding.

Author: Morgan Funtowicz
Date:   2024-07-17 21:56:50 +00:00
Parent: 9220340ff7
Commit: e983ee5bb8
8 changed files with 192 additions and 192 deletions
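
The change described above hinges on one ownership pattern: the GenerationContext handed to the C++ executor callback must stay alive for the whole async decoding loop, so the new code leaks the Box up front and only reclaims it with Box::from_raw once generation is finished. The snippet below is a minimal, self-contained sketch of that pattern, not code from this repository; Context, fake_stream_tokens and on_token are illustrative stand-ins for the real FFI types and entry points.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Stand-in for the state shared with the C++ callback.
struct Context {
    done: Arc<AtomicBool>,
}

/// Stand-in for the FFI entry point: it only sees a raw pointer and a plain fn callback.
fn fake_stream_tokens(ctx: *mut Context, callback: unsafe fn(*mut Context, u32, bool)) {
    // Pretend the executor produced two tokens, the second being final.
    unsafe {
        callback(ctx, 17, false);
        callback(ctx, 42, true);
    }
}

unsafe fn on_token(ctx: *mut Context, token: u32, is_final: bool) {
    println!("token = {token}, final = {is_final}");
    if is_final {
        unsafe {
            (*ctx).done.store(true, Ordering::Relaxed);
        }
    }
}

fn main() {
    let done = Arc::new(AtomicBool::new(false));

    // Leak the Box so the allocation outlives this scope: the raw pointer stays
    // valid for as long as the (possibly async) decoding keeps calling back into it.
    let ctx: &'static mut Context = Box::leak(Box::new(Context { done: done.clone() }));
    let ctx_ptr: *mut Context = ctx;

    fake_stream_tokens(ctx_ptr, on_token);

    // Once decoding reports completion, reclaim ownership so the context is freed exactly once.
    assert!(done.load(Ordering::Relaxed));
    unsafe {
        drop(Box::from_raw(ctx_ptr));
    }
}

The point of the sketch is the pairing: Box::leak hands out a pointer whose allocation cannot be dropped early, and Box::from_raw gives ownership back exactly once so the memory is still freed after the callbacks stop.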

View File

@@ -6,19 +6,19 @@ authors.workspace = true
 homepage.workspace = true
 
 [dependencies]
-async-trait = "0.1.74"
-async-stream = "0.3.5"
+async-trait = "0.1"
+async-stream = "0.3"
 cxx = "1.0"
 text-generation-router = { path = "../../router" }
 tokenizers = { version = "0.19", features = ["hf-hub"] }
-tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
-tokio-stream = "0.1.14"
-clap = { version = "4.5.4", features = ["derive"] }
-thiserror = "1.0.61"
+tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio-stream = "0.1.15"
+clap = { version = "4.5", features = ["derive"] }
+thiserror = "1.0.62"
 tracing = "0.1"
 tracing-opentelemetry = "0.24"
 tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
-log = { version = "0.4.21", features = [] }
+log = { version = "0.4", features = [] }
 
 [build-dependencies]
 cmake = "0.1"

View File

@@ -50,8 +50,7 @@ namespace huggingface::tgi::backends {
                 uint32_t topK,
                 float_t topP,
                 float_t temperature,
-                uint64_t seed,
-                std::optional<int32_t> beamWidth
+                uint64_t seed
         );
 
         /**

View File

@@ -56,8 +56,8 @@ namespace huggingface::tgi::backends {
          */
         size_t StreamTokens(
                 const RequestId requestId,
-                rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
-                rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, float_t, bool)> callback);
+                huggingface::tgi::backends::GenerationContext *ctx,
+                rust::Fn<void(huggingface::tgi::backends::GenerationContext *, uint32_t, float_t, bool)> callback);
     };
 
     /***

View File

@@ -57,10 +57,9 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
         uint32_t topK,
         float_t topP,
         float_t temperature,
-        uint64_t seed,
-        std::optional<int32_t> beamWidth = std::nullopt) {
+        uint64_t seed) {
     return tle::SamplingConfig(
-            beamWidth.value_or(1),
+            1, // TGI only use a single beam
            topK,
            topP,
            std::nullopt,
@@ -116,11 +115,11 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
     );
 #endif
-    const auto maxNumTokens = config["max_num_tokens"_json_pointer].get<size_t>();
+    const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
     const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
 
     const auto sampling = GetSamplingConfig(topK, topP, temperature, seed);
-    const auto output = tle::OutputConfig(false, false, false, true, false);
+    const auto output = tle::OutputConfig(true, false, false, true, false);
 
     return executor.enqueueRequest(
             tle::Request{tokens, maxNewTokens, true, sampling, output});
 }

View File

@@ -8,7 +8,7 @@ use std::time::Duration;
 
 use async_trait::async_trait;
 use cxx::UniquePtr;
-use log::{info, warn};
+use log::{debug, info, warn};
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::sync::RwLock;
@@ -19,7 +19,8 @@ use tracing::{instrument, Level, span};
 use text_generation_router::{FinishReason, Token};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
-use text_generation_router::validation::ValidGenerateRequest;
+use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
+use text_generation_router::validation::ValidationError::UnsupportedModality;
 
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
 
@@ -55,10 +56,12 @@ pub(crate) struct Generation {
     done: Arc<AtomicBool>,
 }
 
-pub struct GenerationContext(
-    UnboundedSender<InferResult<InferStreamResponse>>,
-    Arc<AtomicBool>,
-);
+#[derive(Clone)]
+pub struct GenerationContext {
+    sender: UnboundedSender<InferResult<InferStreamResponse>>,
+    tokenizer: Arc<Tokenizer>,
+    done: Arc<AtomicBool>,
+}
 
 impl Stream for Generation {
     type Item = usize;
@@ -110,24 +113,160 @@ unsafe impl Sync for TensorRtLlmBackendImpl {}
 /// Implements the logic to execute generation with TensorRT-LLM executor API in background
 pub struct TensorRtLlmBackend {
-    // Allowing sending user requests to the TensorRT-LLM executor thread
-    // batcher: UnboundedSender<InferenceContext>,
+    tokenizer: Arc<Tokenizer>,
     backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
 }
 
 impl TensorRtLlmBackend {
     pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
-        _tokenizer: Tokenizer,
+        tokenizer: Tokenizer,
         engine_folder: P,
         _executor_worker_path: Option<PP>,
     ) -> Result<Self, TensorRtLlmBackendError> {
         Ok(TensorRtLlmBackend {
+            tokenizer: Arc::new(tokenizer),
             backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
                 engine_folder.as_ref().to_str().unwrap(),
                 "",
             ))),
         })
     }
+
+    fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
+        if request.top_n_tokens > 1 {
+            return Err(InferError::ValidationError(
+                ValidationError::TopNTokensDisabled,
+            ));
+        }
+
+        match request.inputs.len() {
+            0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
+            2.. => Err(InferError::GenerationError(
+                "TensorRT-LLM backend don't support multi-chunk".into(),
+            )),
+            1 => match request.inputs.first().expect("Single item-chunk") {
+                Chunk::Text(text) => Ok(text),
+                Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
+            },
+        }
+    }
+
+    fn generate(
+        &self,
+        sender: UnboundedSender<InferResult<InferStreamResponse>>,
+        tokens: Vec<u32>,
+        top_k: u32,
+        top_p: f32,
+        temperature: f32,
+        seed: u64,
+    ) {
+        let tokenizer = self.tokenizer.clone();
+        let executor = self.backend.clone();
+
+        // Let's push this in async context
+        tokio::spawn(async move {
+            // Define the generation state
+            let mut generation = Generation {
+                executor: executor.clone(),
+                done: Arc::new(AtomicBool::new(false)),
+            };
+
+            // Define the context over the generation
+            // TODO(asap): Do we really need so many shared-ownership?
+            let ctx = Box::new(GenerationContext {
+                sender: sender.clone(),
+                tokenizer: tokenizer.clone(),
+                done: Arc::clone(&generation.done),
+            });
+
+            // We are leaking the context on-purpose to avoid the box being dropped while there are
+            // still computation ongoing
+            // TODO(asap): Can we achieve the same with an Arc<Box<T>> without the need to go unsafe?
+            let ctx_ = Box::leak(ctx);
+
+            // Submit the request to the batcher
+            let request_id = span!(Level::DEBUG, "submit")
+                .in_scope(|| async {
+                    debug!("Acquiring lock for submit");
+                    let mut handle = executor.write().await;
+                    let request_id =
+                        handle
+                            .pin_mut()
+                            .submit(&tokens, top_k as i32, top_p, temperature, seed);
+
+                    debug!("Releasing lock for submit");
+                    request_id
+                })
+                .await;
+
+            while let Some(_) = generation.next().await {
+                span!(Level::DEBUG, "decode", request_id = request_id)
+                    .in_scope(|| async {
+                        let mut executor_w = executor.write().await;
+
+                        unsafe {
+                            debug!("Acquired write lock stream");
+                            executor_w.pin_mut().stream_tokens(
+                                request_id,
+                                ctx_,
+                                |ctx: *mut GenerationContext,
+                                 token: u32,
+                                 logprob: f32,
+                                 is_final: bool| {
+                                    // let text = ctx
+                                    //     .tokenizer
+                                    //     .decode(&[token], true)
+                                    //     .expect("Failed to decode token");
+                                    info!("Decoded token: {}", token);
+
+                                    let out = if is_final {
+                                        (*ctx).done.store(true, Ordering::Relaxed);
+                                        InferStreamResponse::End {
+                                            token: Token {
+                                                id: token,
+                                                text: "".into(),
+                                                logprob,
+                                                special: false,
+                                            },
+                                            top_tokens: vec![],
+                                            generated_text: GeneratedText {
+                                                text: "".into(),
+                                                generated_tokens: u32::MAX,
+                                                finish_reason: FinishReason::EndOfSequenceToken,
+                                                seed: None,
+                                            },
+                                            start: Instant::now(),
+                                            queued: Instant::now(),
+                                        }
+                                    } else {
+                                        InferStreamResponse::Intermediate {
+                                            token: Token {
+                                                id: token,
+                                                text: "".into(),
+                                                logprob,
+                                                special: false,
+                                            },
+                                            top_tokens: vec![],
+                                        }
+                                    };
+
+                                    (*ctx)
+                                        .sender
+                                        .send(Ok(out))
+                                        .expect("Failed to send back generated token");
+                                },
+                            );
+                            debug!("Releasing write lock stream")
+                        }
+                    })
+                    .await;
+            }
+
+            // "Properly" free the shared context...
+            // TODO: clean that piece of sh** asap
+            unsafe {
+                let _ = Box::from_raw(ctx_);
+            }
+        });
+    }
 }
 
 #[async_trait]
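
As a side note on the validate() helper added above: it relies on an exhaustive match over the chunk count, with the half-open range pattern 2.. catching every multi-chunk request. Below is a standalone sketch of the same shape with simplified stand-in types; Chunk and the plain String errors here are not the repository's types.

// Standalone sketch of the chunk-count validation pattern; types are simplified stand-ins.
#[derive(Debug)]
enum Chunk {
    Text(String),
    Image(Vec<u8>),
}

fn single_text_chunk(inputs: &[Chunk]) -> Result<&String, String> {
    match inputs.len() {
        0 => Err("empty input".into()),
        2.. => Err("multi-chunk input is not supported".into()),
        1 => match inputs.first().expect("single item checked above") {
            Chunk::Text(text) => Ok(text),
            Chunk::Image(_) => Err("image modality is not supported".into()),
        },
    }
}

fn main() {
    let ok = vec![Chunk::Text("Hello".into())];
    let bad = vec![Chunk::Image(vec![0u8; 4])];
    println!("{:?}", single_text_chunk(&ok));
    println!("{:?}", single_text_chunk(&bad));
}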
@@ -135,96 +274,32 @@ impl Backend for TensorRtLlmBackend {
     #[instrument(skip_all)]
     fn schedule(
         &self,
-        _request: ValidGenerateRequest,
+        request: ValidGenerateRequest,
     ) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
+        // Let's add a few more validation
+        let input = TensorRtLlmBackend::validate(&request)?;
+
         // Channel to stream the generated token as they come from the worker thread back to the transport layer
         let (sender, receiver) = unbounded_channel();
-        let executor = self.backend.clone();
-        tokio::spawn(async move {
-            // Submit the request to the batcher
-            let request_id = span!(Level::DEBUG, "[EXECUTOR][SUBMIT]")
-                .in_scope(|| async {
-                    info!("Acquiring lock for submit");
-                    let mut handle = executor.write().await;
-                    let request_id = handle.pin_mut().submit(
-                        &vec![2, 2926, 1503, 603, 20189],
-                        50,
-                        1.0,
-                        1.0,
-                        2014,
-                    );
-                    info!("Releasing lock for submit");
-                    request_id
-                })
-                .await;
-
-            let mut generation = Generation {
-                executor: executor.clone(),
-                done: Arc::new(AtomicBool::new(false)),
-            };
-
-            while let Some(num_tokens_ready) = generation.next().await {
-                span!(
-                    Level::DEBUG,
-                    "[EXECUTOR][GENERATE]",
-                    request_id = request_id,
-                    num_tokens_ready = num_tokens_ready
-                )
-                .in_scope(|| async {
-                    let ctx = Box::new(GenerationContext(
-                        sender.clone(),
-                        Arc::clone(&generation.done),
-                    ));
-                    let mut executor_w = executor.write().await;
-                    info!("Acquired write lock stream");
-                    executor_w.pin_mut().stream_tokens(
-                        request_id,
-                        ctx,
-                        |ctx: Box<GenerationContext>, token: u32, logprob: f32, is_final: bool| {
-                            info!("Sending token: {} (final: {})", token, is_final);
-                            let out = if is_final {
-                                ctx.1.store(true, Ordering::Relaxed);
-                                InferStreamResponse::End {
-                                    token: Token {
-                                        id: token,
-                                        text: "".into(),
-                                        logprob,
-                                        special: false,
-                                    },
-                                    top_tokens: vec![],
-                                    generated_text: GeneratedText {
-                                        text: "".into(),
-                                        generated_tokens: u32::MAX,
-                                        finish_reason: FinishReason::EndOfSequenceToken,
-                                        seed: None,
-                                    },
-                                    start: Instant::now(),
-                                    queued: Instant::now(),
-                                }
-                            } else {
-                                InferStreamResponse::Intermediate {
-                                    token: Token {
-                                        id: token,
-                                        text: "".into(),
-                                        logprob,
-                                        special: false,
-                                    },
-                                    top_tokens: vec![],
-                                }
-                            };
-                            ctx.0
-                                .send(Ok(out))
-                                .expect("Failed to send back generated token");
-                        },
-                    );
-                    info!("Releasing write lock stream")
-                })
-                .await;
-            }
-        });
+
+        // Unpack parameters
+        let params = &request.parameters;
+
+        // Preprocess the inputs to send to TRTLLM backend
+        let encoding = self
+            .tokenizer
+            .encode(input.as_str(), true)
+            .map_err(|e| InferError::GenerationError(e.to_string()))?;
+
+        // Generate the response
+        self.generate(
+            sender,
+            Vec::from(encoding.get_ids()),
+            params.top_k,
+            params.top_p,
+            params.temperature,
+            params.seed,
+        );
 
         Ok(UnboundedReceiverStream::new(receiver))
     }
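
The reworked schedule() above keeps the same transport shape as before: it creates an unbounded channel, hands the sender to a background task, and immediately returns the receiver wrapped in an UnboundedReceiverStream. Below is a minimal sketch of that wiring, not repository code, assuming only the tokio and tokio-stream crates pulled in by the Cargo.toml hunk; the String tokens and error type are placeholders.

// Cargo.toml (assumed): tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync"] }
//                        tokio-stream = "0.1"
use tokio::sync::mpsc::unbounded_channel;
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};

#[tokio::main]
async fn main() {
    // schedule() follows the same shape: hand the sender to a background task,
    // give the receiver (wrapped as a Stream) back to the caller immediately.
    let (sender, receiver) = unbounded_channel::<Result<String, String>>();

    tokio::spawn(async move {
        for token in ["Hello", " world", "<eos>"] {
            // In the real backend this is an InferStreamResponse built inside the FFI callback.
            sender.send(Ok(token.to_string())).expect("receiver dropped");
        }
        // Dropping the sender closes the channel, which ends the consumer loop below.
    });

    let mut stream = UnboundedReceiverStream::new(receiver);
    while let Some(item) = stream.next().await {
        println!("streamed: {:?}", item);
    }
}

Dropping the last sender is what terminates the stream on the consumer side, which is why the real backend keeps a sender clone inside the leaked GenerationContext until decoding completes.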
@@ -233,79 +308,3 @@ impl Backend for TensorRtLlmBackend {
         true
     }
 }
-
-// async fn background_looper<P: AsRef<Path>, PP: AsRef<Path>>(
-//     engine_folder: P,
-//     _executor_worker: Option<PP>,
-//     tokenizer: Tokenizer,
-//     mut receiver: UnboundedReceiver<InferenceContext>,
-// ) {
-//     let mut backend = create_tensorrt_llm_backend(engine_folder.as_ref().to_str().unwrap(), "");
-//
-//     while !(receiver.is_closed()) {
-//         // Receive the incoming request
-//         if let Some(ctx) = receiver.recv().await {
-//             debug!("Processing new incoming request");
-//
-//             // We only support single, textual chunk
-//             if ctx.request.inputs.len() != 1 {
-//                 propagate!(
-//                     ctx,
-//                     Err(InferError::GenerationError(format!(
-//                         "Unsupported multi-chunk ({}) input",
-//                         ctx.request.inputs.len()
-//                     )))
-//                 );
-//             }
-//
-//             let input = ctx
-//                 .request
-//                 .inputs
-//                 .first()
-//                 .expect("Single chunk checked above");
-//             let params = ctx.request.parameters;
-//         }
-//     }
-// Receive the incoming request
-// if let Some(ctx) = receiver.recv().await {
-//     debug!("Processing new incoming request");
-//     // We only support single, textual chunk
-//     if ctx.request.inputs.len() != 1 {
-//         propagate!(
-//             ctx,
-//             Err(InferError::GenerationError(format!(
-//                 "Unsupported multi-chunk ({}) input",
-//                 ctx.request.inputs.len()
-//             )))
-//         );
-//     }
-//
-//     // Unpack parameters
-//     let inputs = ctx.request.inputs;
-//     let params = ctx.request.parameters;
-//
-//     match inputs.first().unwrap() {
-//         Chunk::Text(text) => match tokenizer.encode(text.as_str(), true) {
-//             Err(err) => {
-//                 propagate!(ctx, Err(InferError::GenerationError(err.to_string())))
-//             }
-//             Ok(encoding) => {
-//                 // spawn_blocking(|| {
-//                 //     info!("Submitting request to TensorRT-LLM executor");
-//                 //     let mut executor = backend.blocking_write();
-//                 // })
-//                 // .await
-//                 // .expect("");
-//             }
-//         },
-//         Chunk::Image(_) => propagate!(
-//             ctx,
-//             Err(InferError::GenerationError(
-//                 "Image input is not supported yet.".into(),
-//             ))
-//         ),
-//     }
-// };
-// }

View File

@@ -33,8 +33,8 @@ uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
 
 size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
         const uint64_t requestId,
-        rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
-        rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, float_t, bool)> callback) {
+        huggingface::tgi::backends::GenerationContext *ctx,
+        rust::Fn<void(huggingface::tgi::backends::GenerationContext *, uint32_t, float_t, bool)> callback) {
     size_t numTokens = 0;
     for (const auto &item: Poll(requestId)) {
@@ -44,12 +44,12 @@ size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
             const auto token = decoded.outputTokenIds[0][0];
             const auto isFinal = decoded.isFinal;
-            const auto logProb = decoded.logProbs.value()[0][0];
+            // const auto logProb = decoded.logProbs.value()[0][0];
             ++numTokens;
 
             SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
-            callback(std::move(ctx), token, logProb, isFinal);
+            callback(std::move(ctx), token, 1.0, isFinal);
             SPDLOG_DEBUG("\tStreamTokens -> Post callback");
         } else {
             // TODO : Return rest::Result with error

View File

@@ -54,11 +54,11 @@ mod ffi {
         ) -> u64;
 
         #[rust_name = "stream_tokens"]
-        fn StreamTokens(
+        unsafe fn StreamTokens(
             self: Pin<&mut TensorRtLlmBackendImpl>,
             request_id: u64,
-            ctx: Box<GenerationContext>,
-            cb: fn(Box<GenerationContext>, u32, f32, bool),
+            ctx: *mut GenerationContext,
+            cb: unsafe fn(*mut GenerationContext, u32, f32, bool),
         ) -> usize;
 
         // #[rust_name = "shutdown"]

View File

@@ -777,6 +777,9 @@ pub enum ValidationError {
     InvalidImageContent(String),
     #[error("Could not fetch image: {0}")]
     FailedFetchImage(#[from] reqwest::Error),
+    #[error("{0} modality is not supported")]
+    UnsupportedModality(&'static str)
 }
 
 #[cfg(test)]