From e983ee5bb896adb5ea5b3672fffcad7bef1dbfb5 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Wed, 17 Jul 2024 21:56:50 +0000
Subject: [PATCH] make sure the context is not dropped in the middle of the async decoding.

---
 backends/trtllm/Cargo.toml        |  14 +-
 backends/trtllm/include/backend.h |   3 +-
 backends/trtllm/include/ffi.h     |   4 +-
 backends/trtllm/lib/backend.cpp   |   9 +-
 backends/trtllm/src/backend.rs    | 337 +++++++++++++++---------------
 backends/trtllm/src/ffi.cpp       |   8 +-
 backends/trtllm/src/lib.rs        |   6 +-
 router/src/validation.rs          |   3 +
 8 files changed, 192 insertions(+), 192 deletions(-)

diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml
index 58c69bd4..f07c26d1 100644
--- a/backends/trtllm/Cargo.toml
+++ b/backends/trtllm/Cargo.toml
@@ -6,19 +6,19 @@ authors.workspace = true
 homepage.workspace = true
 
 [dependencies]
-async-trait = "0.1.74"
-async-stream = "0.3.5"
+async-trait = "0.1"
+async-stream = "0.3"
 cxx = "1.0"
 text-generation-router = { path = "../../router" }
 tokenizers = { version = "0.19", features = ["hf-hub"] }
-tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
-tokio-stream = "0.1.14"
-clap = { version = "4.5.4", features = ["derive"] }
-thiserror = "1.0.61"
+tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio-stream = "0.1.15"
+clap = { version = "4.5", features = ["derive"] }
+thiserror = "1.0.62"
 tracing = "0.1"
 tracing-opentelemetry = "0.24"
 tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
-log = { version = "0.4.21", features = [] }
+log = { version = "0.4", features = [] }
 
 [build-dependencies]
 cmake = "0.1"
diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h
index d84bc253..e8f00a38 100644
--- a/backends/trtllm/include/backend.h
+++ b/backends/trtllm/include/backend.h
@@ -50,8 +50,7 @@ namespace huggingface::tgi::backends {
             uint32_t topK,
             float_t topP,
             float_t temperature,
-            uint64_t seed,
-            std::optional<uint32_t> beamWidth
+            uint64_t seed
     );
 
     /**
diff --git a/backends/trtllm/include/ffi.h b/backends/trtllm/include/ffi.h
index 2ebacfdb..9895382f 100644
--- a/backends/trtllm/include/ffi.h
+++ b/backends/trtllm/include/ffi.h
@@ -56,8 +56,8 @@ namespace huggingface::tgi::backends {
          */
        size_t StreamTokens(
                const RequestId requestId,
-               rust::Box<GenerationContext> ctx,
-               rust::Fn<void(rust::Box<GenerationContext>, uint32_t, float_t, bool)> callback);
+               huggingface::tgi::backends::GenerationContext *ctx,
+               rust::Fn<void(huggingface::tgi::backends::GenerationContext *, uint32_t, float_t, bool)> callback);
    };
 
    /***
diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp
index 161dea5a..2b552113 100644
--- a/backends/trtllm/lib/backend.cpp
+++ b/backends/trtllm/lib/backend.cpp
@@ -57,10 +57,9 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
         uint32_t topK,
         float_t topP,
         float_t temperature,
-        uint64_t seed,
-        std::optional<uint32_t> beamWidth = std::nullopt) {
+        uint64_t seed) {
     return tle::SamplingConfig(
-            beamWidth.value_or(1),
+            1, // TGI only uses a single beam
             topK,
             topP,
             std::nullopt,
@@ -116,11 +115,11 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
     );
 #endif
 
-    const auto maxNumTokens = config["max_num_tokens"_json_pointer].get<size_t>();
+    const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
     const auto maxNewTokens = static_cast<uint32_t>(std::max(1ul, maxNumTokens - tokens.size()));
 
     const auto sampling = GetSamplingConfig(topK, topP, temperature, seed);
-    const auto output = tle::OutputConfig(false, false, false, true, false);
+    const auto output = tle::OutputConfig(true, false, false, true, false);
     return executor.enqueueRequest(
             tle::Request{tokens, maxNewTokens, true, sampling, output});
 }
diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs
index d3f56ad9..f985e562 100644
--- a/backends/trtllm/src/backend.rs
+++ b/backends/trtllm/src/backend.rs
@@ -8,7 +8,7 @@ use std::time::Duration;
 
 use async_trait::async_trait;
 use cxx::UniquePtr;
-use log::{info, warn};
+use log::{debug, info, warn};
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::sync::RwLock;
@@ -19,7 +19,8 @@ use tracing::{instrument, Level, span};
 
 use text_generation_router::{FinishReason, Token};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
-use text_generation_router::validation::ValidGenerateRequest;
+use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
+use text_generation_router::validation::ValidationError::UnsupportedModality;
 
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
@@ -55,10 +56,12 @@ pub(crate) struct Generation {
     done: Arc<AtomicBool>,
 }
 
-pub struct GenerationContext(
-    UnboundedSender<InferResult<InferStreamResponse>>,
-    Arc<AtomicBool>,
-);
+#[derive(Clone)]
+pub struct GenerationContext {
+    sender: UnboundedSender<InferResult<InferStreamResponse>>,
+    tokenizer: Arc<Tokenizer>,
+    done: Arc<AtomicBool>,
+}
 
 impl Stream for Generation {
     type Item = usize;
@@ -110,24 +113,160 @@ unsafe impl Sync for TensorRtLlmBackendImpl {}
 
 /// Implements the logic to execute generation with TensorRT-LLM executor API in background
 pub struct TensorRtLlmBackend {
-    // Allowing sending user requests to the TensorRT-LLM executor thread
-    // batcher: UnboundedSender<GenerationContext>,
+    tokenizer: Arc<Tokenizer>,
     backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
 }
 
 impl TensorRtLlmBackend {
     pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
-        _tokenizer: Tokenizer,
+        tokenizer: Tokenizer,
         engine_folder: P,
         _executor_worker_path: Option<PP>,
     ) -> Result<Self, TensorRtLlmBackendError> {
         Ok(TensorRtLlmBackend {
+            tokenizer: Arc::new(tokenizer),
             backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
                 engine_folder.as_ref().to_str().unwrap(),
                 "",
             ))),
         })
     }
+
+    fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
+        if request.top_n_tokens > 1 {
+            return Err(InferError::ValidationError(
+                ValidationError::TopNTokensDisabled,
+            ));
+        }
+
+        match request.inputs.len() {
+            0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
+            2.. => Err(InferError::GenerationError(
+                "TensorRT-LLM backend doesn't support multi-chunk".into(),
+            )),
+            1 => match request.inputs.first().expect("Single item-chunk") {
+                Chunk::Text(text) => Ok(text),
+                Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
+            },
+        }
+    }
+
+    fn generate(
+        &self,
+        sender: UnboundedSender<InferResult<InferStreamResponse>>,
+        tokens: Vec<u32>,
+        top_k: u32,
+        top_p: f32,
+        temperature: f32,
+        seed: u64,
+    ) {
+        let tokenizer = self.tokenizer.clone();
+        let executor = self.backend.clone();
+
+        // Let's push this in async context
+        tokio::spawn(async move {
+            // Define the generation state
+            let mut generation = Generation {
+                executor: executor.clone(),
+                done: Arc::new(AtomicBool::new(false)),
+            };
+
+            // Define the context over the generation
+            // TODO(asap): Do we really need so much shared ownership?
+            let ctx = Box::new(GenerationContext {
+                sender: sender.clone(),
+                tokenizer: tokenizer.clone(),
+                done: Arc::clone(&generation.done),
+            });
+
+            // We are leaking the context on-purpose to avoid the box being dropped while there is
+            // still computation ongoing
+            // TODO(asap): Can we achieve the same with an Arc<GenerationContext> without the need to go unsafe?
+            let ctx_ = Box::leak(ctx);
+
+            // Submit the request to the batcher
+            let request_id = span!(Level::DEBUG, "submit")
+                .in_scope(|| async {
+                    debug!("Acquiring lock for submit");
+                    let mut handle = executor.write().await;
+                    let request_id =
+                        handle
+                            .pin_mut()
+                            .submit(&tokens, top_k as i32, top_p, temperature, seed);
+
+                    debug!("Releasing lock for submit");
+                    request_id
+                })
+                .await;
+
+            while let Some(_) = generation.next().await {
+                span!(Level::DEBUG, "decode", request_id = request_id)
+                    .in_scope(|| async {
+                        let mut executor_w = executor.write().await;
+
+                        unsafe {
+                            debug!("Acquired write lock stream");
+                            executor_w.pin_mut().stream_tokens(
+                                request_id,
+                                ctx_,
+                                |ctx: *mut GenerationContext,
+                                 token: u32,
+                                 logprob: f32,
+                                 is_final: bool| {
+                                    // let text = ctx
+                                    //     .tokenizer
+                                    //     .decode(&[token], true)
+                                    //     .expect("Failed to decode token");
+                                    info!("Decoded token: {}", token);
+                                    let out = if is_final {
+                                        (*ctx).done.store(true, Ordering::Relaxed);
+                                        InferStreamResponse::End {
+                                            token: Token {
+                                                id: token,
+                                                text: "".into(),
+                                                logprob,
+                                                special: false,
+                                            },
+                                            top_tokens: vec![],
+                                            generated_text: GeneratedText {
+                                                text: "".into(),
+                                                generated_tokens: u32::MAX,
+                                                finish_reason: FinishReason::EndOfSequenceToken,
+                                                seed: None,
+                                            },
+                                            start: Instant::now(),
+                                            queued: Instant::now(),
+                                        }
+                                    } else {
+                                        InferStreamResponse::Intermediate {
+                                            token: Token {
+                                                id: token,
+                                                text: "".into(),
+                                                logprob,
+                                                special: false,
+                                            },
+                                            top_tokens: vec![],
+                                        }
+                                    };
+                                    (*ctx)
+                                        .sender
+                                        .send(Ok(out))
+                                        .expect("Failed to send back generated token");
+                                },
+                            );
+                            debug!("Releasing write lock stream")
+                        }
+                    })
+                    .await;
+            }
+
+            // "Properly" free the shared context...
+            // TODO: clean that piece of sh** asap
+            unsafe {
+                let _ = Box::from_raw(ctx_);
+            }
+        });
+    }
 }
 
 #[async_trait]
@@ -135,96 +274,32 @@ impl Backend for TensorRtLlmBackend {
     #[instrument(skip_all)]
     fn schedule(
         &self,
-        _request: ValidGenerateRequest,
+        request: ValidGenerateRequest,
     ) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
+        // Let's add a few more validations
+        let input = TensorRtLlmBackend::validate(&request)?;
+
         // Channel to stream the generated token as they come from the worker thread back to the transport layer
         let (sender, receiver) = unbounded_channel();
-        let executor = self.backend.clone();
 
-        tokio::spawn(async move {
-            // Submit the request to the batcher
-            let request_id = span!(Level::DEBUG, "[EXECUTOR][SUBMIT]")
-                .in_scope(|| async {
-                    info!("Acquiring lock for submit");
-                    let mut handle = executor.write().await;
-                    let request_id = handle.pin_mut().submit(
-                        &vec![2, 2926, 1503, 603, 20189],
-                        50,
-                        1.0,
-                        1.0,
-                        2014,
-                    );
+        // Unpack parameters
+        let params = &request.parameters;
 
-                    info!("Releasing lock for submit");
-                    request_id
-                })
-                .await;
+        // Preprocess the inputs to send to TRTLLM backend
+        let encoding = self
+            .tokenizer
+            .encode(input.as_str(), true)
+            .map_err(|e| InferError::GenerationError(e.to_string()))?;
 
-            let mut generation = Generation {
-                executor: executor.clone(),
-                done: Arc::new(AtomicBool::new(false)),
-            };
-
-            while let Some(num_tokens_ready) = generation.next().await {
-                span!(
-                    Level::DEBUG,
-                    "[EXECUTOR][GENERATE]",
-                    request_id = request_id,
-                    num_tokens_ready = num_tokens_ready
-                )
-                .in_scope(|| async {
-                    let ctx = Box::new(GenerationContext(
-                        sender.clone(),
-                        Arc::clone(&generation.done),
-                    ));
-                    let mut executor_w = executor.write().await;
-
-                    info!("Acquired write lock stream");
-                    executor_w.pin_mut().stream_tokens(
-                        request_id,
-                        ctx,
-                        |ctx: Box<GenerationContext>, token: u32, logprob: f32, is_final: bool| {
-                            info!("Sending token: {} (final: {})", token, is_final);
-                            let out = if is_final {
-                                ctx.1.store(true, Ordering::Relaxed);
-                                InferStreamResponse::End {
-                                    token: Token {
-                                        id: token,
-                                        text: "".into(),
-                                        logprob,
-                                        special: false,
-                                    },
-                                    top_tokens: vec![],
-                                    generated_text: GeneratedText {
-                                        text: "".into(),
-                                        generated_tokens: u32::MAX,
-                                        finish_reason: FinishReason::EndOfSequenceToken,
-                                        seed: None,
-                                    },
-                                    start: Instant::now(),
-                                    queued: Instant::now(),
-                                }
-                            } else {
-                                InferStreamResponse::Intermediate {
-                                    token: Token {
-                                        id: token,
-                                        text: "".into(),
-                                        logprob,
-                                        special: false,
-                                    },
-                                    top_tokens: vec![],
-                                }
-                            };
-                            ctx.0
-                                .send(Ok(out))
-                                .expect("Failed to send back generated token");
-                        },
-                    );
-                    info!("Releasing write lock stream")
-                })
-                .await;
-            }
-        });
+        // Generate the response
+        self.generate(
+            sender,
+            Vec::from(encoding.get_ids()),
+            params.top_k,
+            params.top_p,
+            params.temperature,
+            params.seed,
+        );
 
         Ok(UnboundedReceiverStream::new(receiver))
     }
@@ -233,79 +308,3 @@ impl Backend for TensorRtLlmBackend {
         true
     }
 }
-
-// async fn background_looper<P: AsRef<Path>, PP: AsRef<Path>>(
-//     engine_folder: P,
-//     _executor_worker: Option<PP>,
-//     tokenizer: Tokenizer,
-//     mut receiver: UnboundedReceiver<GenerationContext>,
-// ) {
-//     let mut backend = create_tensorrt_llm_backend(engine_folder.as_ref().to_str().unwrap(), "");
-//
-//     while !(receiver.is_closed()) {
-//         // Receive the incoming request
-//         if let Some(ctx) = receiver.recv().await {
-//             debug!("Processing new incoming request");
-//
-//             // We only support single, textual chunk
-//             if ctx.request.inputs.len() != 1 {
-//                 propagate!(
-//                     ctx,
-//                     Err(InferError::GenerationError(format!(
-//                         "Unsupported multi-chunk ({}) input",
-//                         ctx.request.inputs.len()
-//                     )))
-//                 );
-//             }
-//
-//             let input = ctx
-//                 .request
-//                 .inputs
-//                 .first()
-//                 .expect("Single chunk checked above");
-//             let params = ctx.request.parameters;
-//         }
-//     }
-
-// Receive the incoming request
-// if let Some(ctx) = receiver.recv().await {
-//     debug!("Processing new incoming request");
-
-//     // We only support single, textual chunk
-//     if ctx.request.inputs.len() != 1 {
-//         propagate!(
-//             ctx,
-//             Err(InferError::GenerationError(format!(
-//                 "Unsupported multi-chunk ({}) input",
-//                 ctx.request.inputs.len()
-//             )))
-//         );
-//     }
-//
-//     // Unpack parameters
-//     let inputs = ctx.request.inputs;
-//     let params = ctx.request.parameters;
-//
-//     match inputs.first().unwrap() {
-//         Chunk::Text(text) => match tokenizer.encode(text.as_str(), true) {
-//             Err(err) => {
-//                 propagate!(ctx, Err(InferError::GenerationError(err.to_string())))
-//             }
-//             Ok(encoding) => {
-//                 // spawn_blocking(|| {
-//                 //     info!("Submitting request to TensorRT-LLM executor");
-//                 //     let mut executor = backend.blocking_write();
-//                 // })
-//                 // .await
-//                 // .expect("");
-//             }
-//         },
-//         Chunk::Image(_) => propagate!(
-//             ctx,
-//             Err(InferError::GenerationError(
-//                 "Image input is not supported yet.".into(),
-//             ))
-//         ),
-//     }
-// };
-// }
diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp
index 43d6c9f2..39c7104c 100644
--- a/backends/trtllm/src/ffi.cpp
+++ b/backends/trtllm/src/ffi.cpp
@@ -33,8 +33,8 @@ uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
 
 size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
         const uint64_t requestId,
-        rust::Box<GenerationContext> ctx,
-        rust::Fn<void(rust::Box<GenerationContext>, uint32_t, float_t, bool)> callback) {
+        huggingface::tgi::backends::GenerationContext *ctx,
+        rust::Fn<void(huggingface::tgi::backends::GenerationContext *, uint32_t, float_t, bool)> callback) {
     size_t numTokens = 0;
 
     for (const auto &item: Poll(requestId)) {
@@ -44,12 +44,12 @@ size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
             const auto token = decoded.outputTokenIds[0][0];
             const auto isFinal = decoded.isFinal;
-            const auto logProb = decoded.logProbs.value()[0][0];
+//            const auto logProb = decoded.logProbs.value()[0][0];
 
             ++numTokens;
 
             SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
-            callback(std::move(ctx), token, logProb, isFinal);
+            callback(std::move(ctx), token, 1.0, isFinal);
             SPDLOG_DEBUG("\tStreamTokens -> Post callback");
         } else {
             // TODO : Return rest::Result with error
diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs
index a2611c66..6506406d 100644
--- a/backends/trtllm/src/lib.rs
+++ b/backends/trtllm/src/lib.rs
@@ -54,11 +54,11 @@ mod ffi {
         ) -> u64;
 
         #[rust_name = "stream_tokens"]
-        fn StreamTokens(
+        unsafe fn StreamTokens(
             self: Pin<&mut TensorRtLlmBackendImpl>,
             request_id: u64,
-            ctx: Box<GenerationContext>,
-            cb: fn(Box<GenerationContext>, u32, f32, bool),
+            ctx: *mut GenerationContext,
+            cb: unsafe fn(*mut GenerationContext, u32, f32, bool),
         ) -> usize;
 
         // #[rust_name = "shutdown"]
diff --git a/router/src/validation.rs b/router/src/validation.rs
index 2a00b08e..97e1480c 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -777,6 +777,9 @@ pub enum ValidationError {
     InvalidImageContent(String),
     #[error("Could not fetch image: {0}")]
     FailedFetchImage(#[from] reqwest::Error),
+    #[error("{0} modality is not supported")]
+    UnsupportedModality(&'static str)
+
 }
 
 #[cfg(test)]
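
Note on the core of this change: the GenerationContext handed to the C++ executor must stay alive for as long as StreamTokens may still invoke the Rust callback, which is why generate() leaks the Box up front and only reclaims it with Box::from_raw once the decoding loop has finished. The following minimal, self-contained sketch illustrates that leak-then-reclaim lifetime pattern in isolation; Ctx and fake_stream_tokens are hypothetical stand-ins for GenerationContext and the FFI call, not code from this backend.

    use std::sync::atomic::{AtomicBool, Ordering};

    // Hypothetical stand-in for the backend's GenerationContext.
    struct Ctx {
        tokens_seen: u32,
        done: AtomicBool,
    }

    // Stand-in for the C++ StreamTokens call: it only ever sees a raw pointer and a callback.
    fn fake_stream_tokens(ctx: *mut Ctx, cb: fn(*mut Ctx, u32, bool)) {
        cb(ctx, 42, false);
        cb(ctx, 43, true);
    }

    fn main() {
        // Leak the box so the allocation outlives any single call frame: the raw pointer
        // stays valid across every poll/callback round, mirroring `Box::leak(ctx)` above.
        let ctx_ptr: *mut Ctx = Box::leak(Box::new(Ctx {
            tokens_seen: 0,
            done: AtomicBool::new(false),
        }));

        // The "decoding loop": the same pointer is handed out on every round.
        loop {
            fake_stream_tokens(ctx_ptr, |ctx, _token, is_final| {
                // SAFETY: the allocation was leaked above and is only reclaimed after the loop.
                let ctx = unsafe { &mut *ctx };
                ctx.tokens_seen += 1;
                if is_final {
                    ctx.done.store(true, Ordering::Relaxed);
                }
            });

            // SAFETY: same allocation, still alive; read the completion flag between rounds.
            if unsafe { (*ctx_ptr).done.load(Ordering::Relaxed) } {
                break;
            }
        }

        // No more callbacks can fire: take ownership back so the memory is actually freed,
        // mirroring the `Box::from_raw(ctx_)` at the end of `generate`.
        let ctx = unsafe { Box::from_raw(ctx_ptr) };
        assert_eq!(ctx.tokens_seen, 2);
    }

The final assert only demonstrates that the reclaimed box still holds the state written by the callbacks, i.e. the context was never dropped mid-decoding.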
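On the TODO(asap) question in generate() (whether an Arc could replace the leak): one possible shape, again with hypothetical stand-ins rather than the backend's real types, is Arc::into_raw/Arc::from_raw. The strong count handed across the boundary keeps the allocation alive, and the Rust side can keep its own clone, but the boundary itself still requires unsafe, so this changes the ownership story rather than removing the unsafety.

    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    // Hypothetical context type; the field mirrors what the callback needs (a completion flag).
    struct Ctx {
        done: AtomicBool,
    }

    // Stand-in for the FFI call taking an opaque pointer plus a callback.
    fn fake_stream_tokens(ctx: *const Ctx, cb: fn(*const Ctx, bool)) {
        cb(ctx, true);
    }

    fn main() {
        let ctx = Arc::new(Ctx { done: AtomicBool::new(false) });

        // Hand one strong reference to the "C++ side" as a raw pointer; the allocation
        // cannot be freed until this count is given back with Arc::from_raw.
        let raw = Arc::into_raw(Arc::clone(&ctx));

        fake_stream_tokens(raw, |ctx, is_final| {
            // SAFETY: `ctx` came from Arc::into_raw and the strong count it represents
            // is only released after this call returns.
            let ctx = unsafe { &*ctx };
            if is_final {
                ctx.done.store(true, Ordering::Relaxed);
            }
        });

        // Give the borrowed strong count back; `ctx` on the Rust side stays usable.
        unsafe { drop(Arc::from_raw(raw)) };
        assert!(ctx.done.load(Ordering::Relaxed));
    }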