From 7784a21d48a3a53669498d2c96f66a3965999411 Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Tue, 16 Jul 2024 20:08:10 +0000 Subject: [PATCH] impl RwLock scenario for TensorRtLllmBackend --- backends/trtllm/Cargo.toml | 4 + backends/trtllm/build.rs | 2 + backends/trtllm/cmake/trtllm.cmake | 9 +- backends/trtllm/include/ffi.h | 20 +- backends/trtllm/src/backend.rs | 385 ++++++++++++++++++++--------- backends/trtllm/src/ffi.cpp | 52 ++-- backends/trtllm/src/lib.rs | 24 +- backends/trtllm/src/main.rs | 29 ++- 8 files changed, 352 insertions(+), 173 deletions(-) diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml index 49b8830d..58c69bd4 100644 --- a/backends/trtllm/Cargo.toml +++ b/backends/trtllm/Cargo.toml @@ -15,6 +15,10 @@ tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot" tokio-stream = "0.1.14" clap = { version = "4.5.4", features = ["derive"] } thiserror = "1.0.61" +tracing = "0.1" +tracing-opentelemetry = "0.24" +tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } +log = { version = "0.4.21", features = [] } [build-dependencies] cmake = "0.1" diff --git a/backends/trtllm/build.rs b/backends/trtllm/build.rs index 6736fe0c..f8f2a2c0 100644 --- a/backends/trtllm/build.rs +++ b/backends/trtllm/build.rs @@ -33,6 +33,8 @@ fn main() { "debug" => format!("{}d", dependency), _ => String::from(dependency), }; + let dep_path = deps_folder.join(format!("{}-build", dependency)); + println!("cargo:rustc-link-search={}", dep_path.display()); println!("cargo:rustc-link-lib=static={}", dep_name); } diff --git a/backends/trtllm/cmake/trtllm.cmake b/backends/trtllm/cmake/trtllm.cmake index 342fe4fa..c3042bb6 100644 --- a/backends/trtllm/cmake/trtllm.cmake +++ b/backends/trtllm/cmake/trtllm.cmake @@ -17,14 +17,11 @@ else () set(FAST_BUILD OFF) endif () -# This line turn off DEBUG in TRTLLM logger which is quite spammy -add_compile_definitions(NDEBUG OFF) - fetchcontent_declare( trtllm - GIT_REPOSITORY https://github.com/nvidia/tensorrt-llm.git - GIT_TAG a96cccafcf6365c128f004f779160951f8c0801c - GIT_SHALLOW TRUE + GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git + GIT_TAG 9691e12bce7ae1c126c435a049eb516eb119486c + GIT_SHALLOW FALSE ) fetchcontent_makeavailable(trtllm) message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}") diff --git a/backends/trtllm/include/ffi.h b/backends/trtllm/include/ffi.h index 6b77b7b4..2ebacfdb 100644 --- a/backends/trtllm/include/ffi.h +++ b/backends/trtllm/include/ffi.h @@ -5,7 +5,7 @@ #ifndef TGI_TRTLLM_BACKEND_FFI_H #define TGI_TRTLLM_BACKEND_FFI_H -//#include "rust/cxx.h" +#include #include "backend.h" namespace huggingface::tgi::backends { @@ -17,9 +17,9 @@ namespace huggingface::tgi::backends { namespace huggingface::tgi::backends { - struct GenerationContext; +// struct GenerationContext; - class TensorRtLlmBackendImpl : TensorRtLlmBackend { + class TensorRtLlmBackendImpl : public TensorRtLlmBackend { public: /*** * @@ -37,7 +37,6 @@ namespace huggingface::tgi::backends { /*** * * @param tokens - * @param maxNewTokens * @param topK * @param topP * @param temperature @@ -45,17 +44,20 @@ namespace huggingface::tgi::backends { * @return */ [[nodiscard("returned request id should be used to refer to the request's generation result later on")]] - uint64_t Submit(rust::Slice tokens, int32_t maxNewTokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed); + uint64_t + Submit(rust::Slice tokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed); /*** * * @param 
requestId
-     * @param handler
+     * @param ctx
+     * @param callback
      * @return
      */
-        uint32_t Stream(rust::Box<GenerationContext> ctx,
-                        uint64_t requestId,
-                        rust::Fn<void(rust::Box<GenerationContext>, uint32_t, uint32_t, bool)> handler);
+        size_t StreamTokens(
+                const RequestId requestId,
+                rust::Box<GenerationContext> ctx,
+                rust::Fn<void(rust::Box<GenerationContext>, uint32_t, float_t, bool)> callback);
     };
 
     /***
diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs
index eec4e081..b59e2006 100644
--- a/backends/trtllm/src/backend.rs
+++ b/backends/trtllm/src/backend.rs
@@ -1,160 +1,311 @@
-use std::cell::RefCell;
+use std::future::Future;
 use std::path::Path;
+use std::pin::{pin, Pin};
+use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::task::{Context, Poll};
+use std::time::Duration;
 
 use async_trait::async_trait;
 use cxx::UniquePtr;
+use log::{info, warn};
 use tokenizers::Tokenizer;
-use tokio::sync::mpsc;
-use tokio::time::Instant;
+use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
+use tokio::sync::RwLock;
+use tokio::time::{Instant, sleep};
+use tokio_stream::{Stream, StreamExt};
 use tokio_stream::wrappers::UnboundedReceiverStream;
+use tracing::{instrument, Level, span};
 
 use text_generation_router::{FinishReason, Token};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
-use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidParameters};
+use text_generation_router::validation::ValidGenerateRequest;
 
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
 
+// macro_rules! propagate {
+//     ($ctx: expr, $res: expr) => {
+//         $ctx.sender
+//             .send($res)
+//             .expect("Failed to propagate error back to the transport layer")
+//     };
+// }
+
 type InferResult<T> = Result<T, InferError>;
 
-pub struct GenerationContext(mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>);
+/// Holds the user-provided input to be executed, along with a channel allowing
+/// the generated tokens for that request to be bubbled up to the end stream.
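As an aside, a minimal, self-contained sketch of the channel-plus-flag pattern this comment describes: one side pushes generated tokens through an unbounded channel while a shared flag signals completion. The types below are simplified stand-ins, not the backend's actual GenerationContext, and the sketch assumes a tokio dependency with the rt, macros and sync features enabled.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use tokio::sync::mpsc::unbounded_channel;

#[tokio::main]
async fn main() {
    // Channel carrying either a generated token id or an error message.
    let (sender, mut receiver) = unbounded_channel::<Result<u32, String>>();
    // Shared flag the producer flips once the request is finished.
    let done = Arc::new(AtomicBool::new(false));

    let done_producer = Arc::clone(&done);
    tokio::spawn(async move {
        // Producer side: stream a few tokens, then mark the generation done.
        for token in [101u32, 102, 103] {
            sender.send(Ok(token)).expect("receiver dropped");
        }
        done_producer.store(true, Ordering::Relaxed);
        // `sender` is dropped here, which closes the channel.
    });

    // Consumer side: drain tokens until the producer hangs up.
    while let Some(item) = receiver.recv().await {
        println!("received: {:?}", item);
    }
    println!("generation done: {}", done.load(Ordering::Relaxed));
}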
+// pub struct InferenceContext { +// /// User provided request +// request: ValidGenerateRequest, +// +// /// Inter-process communication handler moving token from the executor thread to the HTTP server +// sender: UnboundedSender>, +// +// /// Pin the instant this inference context was submitted +// when: Instant, +// +// /// Span that will live as long as entry +// span: Span, +// } -pub struct TrtLLmBackend { - tokenizer: Tokenizer, - inner: RefCell>, +pub(crate) struct Generation { + executor: Arc>>, + done: Arc, } -unsafe impl Sync for TrtLLmBackend {} -unsafe impl Send for TrtLLmBackend {} +pub struct GenerationContext( + UnboundedSender>, + Arc, +); -impl TrtLLmBackend { - pub fn new>( - tokenizer: Tokenizer, - engine_folder: P, - ) -> Result { - let engine_folder = engine_folder.as_ref(); - let inner = create_tensorrt_llm_backend(engine_folder.to_str().unwrap(), ""); +impl Stream for Generation { + type Item = usize; - Ok(Self { - tokenizer, - inner: RefCell::new(inner), - }) + fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { + if self.done.load(Ordering::Relaxed) { + Poll::Ready(None) + } else { + let pinned = pin!(self.executor.read()); + match pinned.poll(ctx) { + Poll::Ready(executor_r) => { + let ready = executor_r.num_responses_ready(); + if ready == 0 { + let waker = ctx.waker().clone(); + tokio::spawn(async { + sleep(Duration::from_millis(10)).await; + waker.wake(); + }); + Poll::Pending + } else { + info!("Ready: {}", ready); + let waker = ctx.waker().clone(); + tokio::spawn(async { + sleep(Duration::from_millis(100)).await; + waker.wake(); + }); + Poll::Ready(Some(ready)) + } + } + Poll::Pending => { + let waker = ctx.waker().clone(); + tokio::spawn(async { + sleep(Duration::from_millis(100)).await; + waker.wake(); + }); + Poll::Pending + } + } + } } - fn infer_text( - &self, - ctx: GenerationContext, - text: &str, - params: ValidParameters, - ) -> InferResult<()> { - // Keep track of processing time - let start = Instant::now(); + fn size_hint(&self) -> (usize, Option) { + (1, None) + } +} - // Encode the input - let ctx = Box::new(ctx); - let encoding = self - .tokenizer - .encode(text, true) - .map_err(|e| InferError::ToolError(e.to_string()))?; +unsafe impl Send for TensorRtLlmBackendImpl {} +unsafe impl Sync for TensorRtLlmBackendImpl {} - // Submit the request to the backend and retrieve the handle to query its status - let request_id = self - .inner - .borrow_mut() - .as_mut() - .expect("Failed to retrieve pointer to TRTLLM backend") - .submit( - encoding.get_ids(), - 128, - params.top_k as i32, - params.top_p, - params.temperature, - params.seed, - ); +/// Implements the logic to execute generation with TensorRT-LLM executor API in background +pub struct TensorRtLlmBackend { + // Allowing sending user requests to the TensorRT-LLM executor thread + // batcher: UnboundedSender, + backend: Arc>>, +} - // Stream generated tokens - // spawn_blocking(move || { - let num_generated_tokens = self - .inner - .borrow_mut() - .as_mut() - .expect("Failed to retrieve pointer to TRTLLM backend") - .stream(ctx, request_id, |ctx, token, step, is_final| { - // self.tokenizer.decode(&*[token], true).unwrap(); - let sender = ctx.0; - let token = Token { - id: token, - text: String::from(""), - logprob: 1.0f32, - special: false, - }; - - sender - .send(Ok(InferStreamResponse::Intermediate { - token, - top_tokens: vec![], - })) - .unwrap() - }); - - // Notify the end - let _ = ctx.0.send(Ok(InferStreamResponse::End { - token: Token { - id: 0, - text: 
String::from(""), - logprob: 1.0f32, - special: false, - }, - top_tokens: vec![], - generated_text: GeneratedText { - text: String::from(""), - generated_tokens: num_generated_tokens, - finish_reason: FinishReason::EndOfSequenceToken, - seed: Some(params.seed), - }, - start, - queued: Instant::now(), - })); - // }); - - Ok(()) +impl TensorRtLlmBackend { + pub fn new + Send + 'static, PP: AsRef + Send + 'static>( + _tokenizer: Tokenizer, + engine_folder: P, + _executor_worker_path: Option, + ) -> Result { + Ok(TensorRtLlmBackend { + backend: Arc::new(RwLock::new(create_tensorrt_llm_backend( + engine_folder.as_ref().to_str().unwrap(), + "", + ))), + }) } } #[async_trait] -impl Backend for TrtLLmBackend { +impl Backend for TensorRtLlmBackend { + #[instrument(skip_all)] fn schedule( &self, - request: ValidGenerateRequest, + _request: ValidGenerateRequest, ) -> InferResult>> { - let (sender, receiver) = mpsc::unbounded_channel(); - let ctx = GenerationContext(sender); + // Channel to stream the generated token as they come from the worker thread back to the transport layer + let (sender, receiver) = unbounded_channel(); - // Unpack parameters - let params = request.parameters; + let executor = self.backend.clone(); + tokio::spawn(async move { + // Submit the request to the batcher + let request_id = span!(Level::DEBUG, "[EXECUTOR][SUBMIT]") + .in_scope(|| async { + info!("Acquiring lock for submit"); + let mut handle = executor.write().await; + let request_id = handle.pin_mut().submit( + &vec![2, 2926, 1503, 603, 20189], + 50, + 1.0, + 1.0, + 2014, + ); - // Ensure we are running in the right conditions for the input (i.e. single textual chunk) - let input = match request.inputs.len() { - 0 => Err(InferError::GenerationError("No input provided".into())), - 1 => Ok(request.inputs.first().unwrap()), - _ => Err(InferError::GenerationError(format!( - "Unsupported multi-chunks ({}) inference.", - request.inputs.len() - ))), - }?; + info!("Releasing lock for submit"); + return request_id; + }) + .await; - // Currently we handle single chunk of text - match input { - Chunk::Text(text) => { - self.infer_text(ctx, &**text, params)?; + let mut generation = Generation { + executor: executor.clone(), + done: Arc::new(AtomicBool::new(false)), + }; + + while let Some(num_tokens_ready) = generation.next().await { + span!( + Level::DEBUG, + "[EXECUTOR][GENERATE]", + request_id = request_id, + num_tokens_ready = num_tokens_ready + ) + .in_scope(|| async { + let ctx = Box::new(GenerationContext( + sender.clone(), + Arc::clone(&generation.done), + )); + let mut executor_w = executor.write().await; + + info!("Acquired write lock stream"); + executor_w.pin_mut().stream_tokens( + request_id, + ctx, + |ctx: Box, token: u32, logprob: f32, is_final: bool| { + info!("Sending token: {} (final: {})", token, is_final); + let out = if is_final { + ctx.1.store(true, Ordering::Relaxed); + InferStreamResponse::End { + token: Token { + id: token, + text: "".into(), + logprob, + special: false, + }, + top_tokens: vec![], + generated_text: GeneratedText { + text: "".into(), + generated_tokens: u32::MAX, + finish_reason: FinishReason::EndOfSequenceToken, + seed: None, + }, + start: Instant::now(), + queued: Instant::now(), + } + } else { + InferStreamResponse::Intermediate { + token: Token { + id: token, + text: "".into(), + logprob, + special: false, + }, + top_tokens: vec![], + } + }; + ctx.0 + .send(Ok(out)) + .expect("Failed to send back generated token"); + }, + ); + info!("Releasing write lock stream") + }) + .await; } - 
Chunk::Image(_) => panic!("Unsupported"), - }; + }); Ok(UnboundedReceiverStream::new(receiver)) } async fn health(&self, _current_health: bool) -> bool { - self.inner.borrow_mut().is_ready() + true } } + +// async fn background_looper, PP: AsRef>( +// engine_folder: P, +// _executor_worker: Option, +// tokenizer: Tokenizer, +// mut receiver: UnboundedReceiver, +// ) { +// let mut backend = create_tensorrt_llm_backend(engine_folder.as_ref().to_str().unwrap(), ""); +// +// while !(receiver.is_closed()) { +// // Receive the incoming request +// if let Some(ctx) = receiver.recv().await { +// debug!("Processing new incoming request"); +// +// // We only support single, textual chunk +// if ctx.request.inputs.len() != 1 { +// propagate!( +// ctx, +// Err(InferError::GenerationError(format!( +// "Unsupported multi-chunk ({}) input", +// ctx.request.inputs.len() +// ))) +// ); +// } +// +// let input = ctx +// .request +// .inputs +// .first() +// .expect("Single chunk checked above"); +// let params = ctx.request.parameters; +// } +// } + +// Receive the incoming request +// if let Some(ctx) = receiver.recv().await { +// debug!("Processing new incoming request"); + +// // We only support single, textual chunk +// if ctx.request.inputs.len() != 1 { +// propagate!( +// ctx, +// Err(InferError::GenerationError(format!( +// "Unsupported multi-chunk ({}) input", +// ctx.request.inputs.len() +// ))) +// ); +// } +// +// // Unpack parameters +// let inputs = ctx.request.inputs; +// let params = ctx.request.parameters; +// +// match inputs.first().unwrap() { +// Chunk::Text(text) => match tokenizer.encode(text.as_str(), true) { +// Err(err) => { +// propagate!(ctx, Err(InferError::GenerationError(err.to_string()))) +// } +// Ok(encoding) => { +// // spawn_blocking(|| { +// // info!("Submitting request to TensorRT-LLM executor"); +// // let mut executor = backend.blocking_write(); +// // }) +// // .await +// // .expect(""); +// } +// }, +// Chunk::Image(_) => propagate!( +// ctx, +// Err(InferError::GenerationError( +// "Image input is not supported yet.".into(), +// )) +// ), +// } +// }; +// } diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp index 47e73a6f..2920eda0 100644 --- a/backends/trtllm/src/ffi.cpp +++ b/backends/trtllm/src/ffi.cpp @@ -7,6 +7,7 @@ #include #include +#include #include "backends/trtllm/include/ffi.h" @@ -21,42 +22,43 @@ bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const { } uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit( - rust::Slice tokens, - int32_t maxNewTokens, int32_t topK, float_t topP, - float_t temperature, uint64_t seed) { + rust::Slice tokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed) { // This will copy all the items from the initial slice std::vector tokens_(tokens.size()); tokens_.assign(tokens.begin(), tokens.end()); - return TensorRtLlmBackend::Submit(std::move(tokens_), maxNewTokens, topK, topP, temperature, seed); + return TensorRtLlmBackend::Submit(std::move(tokens_), topK, topP, temperature, seed); } -uint32_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Stream( - rust::Box ctx, - uint64_t requestId, - rust::Fn, uint32_t, uint32_t, bool)> handler) { - bool isDone = false; - uint32_t numGeneratedTokens = 0; +size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(const uint64_t requestId, + rust::Box ctx, + rust::Fn, uint32_t, float_t, bool)> callback) { - do { - const auto responses = Poll(requestId); - for (const auto &response: responses) { - if 
(response.hasError()) { - isDone = true; - // TODO : bubble up the error to rust - } else { - const auto generation = response.getResult(); - const auto token = generation.outputTokenIds[0][0]; - isDone = generation.isFinal; + SPDLOG_INFO("Entering StreamTokens"); + for (const auto &item: Poll(requestId)) { + if (!item.hasError()) { + SPDLOG_INFO("\tStreamTokens -> Decoding token..."); + const auto decoded = item.getResult(); + SPDLOG_INFO("\tStreamTokens -> Successfully read decoded token ({})", decoded.outputTokenIds[0].size()); - // Propagate through the handler - handler(std::move(ctx), token, numGeneratedTokens, isDone); - } + const auto token = decoded.outputTokenIds[0][0]; + const auto isFinal = decoded.isFinal; +// const auto logProb = decoded.logProbs.value()[0][0]; + const auto logProb = 0.0; + + SPDLOG_INFO(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal); + callback(std::move(ctx), token, logProb, isFinal); + SPDLOG_INFO("\tStreamTokens -> Post callback"); + } else { + // TODO : Return rest::Result with error + SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", item.getErrorMsg()); + callback(std::move(ctx), 0, 0.0, true); } - } while (!isDone); + } - return numGeneratedTokens; + SPDLOG_INFO("Exiting StreamTokens"); + return 0; } std::unique_ptr diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs index ef4b2907..a2611c66 100644 --- a/backends/trtllm/src/lib.rs +++ b/backends/trtllm/src/lib.rs @@ -17,7 +17,7 @@ mod ffi { /// Represent an instance of the underlying TensorRT-LLM backend type TensorRtLlmBackendImpl; - /// Create an instance backed behind an std::unique_ptr to manage the lifespan of the backend + /// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend /// /// # Arguments /// @@ -37,29 +37,31 @@ mod ffi { executor_worker: &str, ) -> UniquePtr; - #[rust_name = "is_ready"] - fn IsReady(self: &TensorRtLlmBackendImpl) -> bool; + // #[rust_name = "is_ready"] + // fn IsReady(self: &TensorRtLlmBackendImpl) -> bool; + + #[rust_name = "num_responses_ready"] + fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize; #[rust_name = "submit"] fn Submit( self: Pin<&mut TensorRtLlmBackendImpl>, tokens: &[u32], - max_new_tokens: i32, top_k: i32, top_p: f32, temperature: f32, seed: u64, ) -> u64; - #[rust_name = "stream"] - fn Stream( + #[rust_name = "stream_tokens"] + fn StreamTokens( self: Pin<&mut TensorRtLlmBackendImpl>, - ctx: Box, request_id: u64, - callback: fn(Box, u32, u32, bool), - ) -> u32; + ctx: Box, + cb: fn(Box, u32, f32, bool), + ) -> usize; - #[rust_name = "shutdown"] - fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>); + // #[rust_name = "shutdown"] + // fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>); } } diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs index a871a4c5..5d989feb 100644 --- a/backends/trtllm/src/main.rs +++ b/backends/trtllm/src/main.rs @@ -1,9 +1,11 @@ use std::collections::HashMap; +use std::path::PathBuf; use clap::Parser; use tokenizers::{FromPretrainedParameters, Tokenizer}; -use text_generation_backends_trtllm::{errors::TensorRtLlmBackendError, TrtLLmBackend}; +use text_generation_backends_trtllm::errors::TensorRtLlmBackendError; +use text_generation_backends_trtllm::TensorRtLlmBackend; use text_generation_router::server; /// App Configuration @@ -53,7 +55,13 @@ struct Args { #[clap(default_value = "4", long, env)] max_client_batch_size: usize, #[clap(long, env)] - auth_token: Option + auth_token: Option, + #[clap( 
+        long,
+        env,
+        help = "Path to the TensorRT-LLM Orchestrator Worker binary"
+    )]
+    executor_worker: Option<PathBuf>,
 }
 
 #[tokio::main]
@@ -83,7 +91,8 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         cors_allow_origin,
         messages_api_enabled,
         max_client_batch_size,
-        auth_token
+        auth_token,
+        executor_worker,
     } = args;
 
     // Launch Tokio runtime
@@ -114,6 +123,15 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         }
     }
 
+    if let Some(ref executor_worker) = executor_worker {
+        if !executor_worker.exists() {
+            return Err(TensorRtLlmBackendError::ArgumentValidation(format!(
+                "`executor_worker` specified path doesn't exist: {}",
+                executor_worker.display()
+            )));
+        }
+    }
+
     // Run server
     let tokenizer = Tokenizer::from_pretrained(
         tokenizer_name.clone(),
@@ -122,9 +140,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
            user_agent: HashMap::new(),
            auth_token,
        }),
-    ).map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
+    )
+    .map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
 
-    let backend = TrtLLmBackend::new(tokenizer, model_id)?;
+    let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?;
     server::run(
         backend,
         max_concurrent_requests,
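For reference, a standalone sketch of the `executor_worker` validation introduced above: an optional worker path, when provided, must exist on disk before the server is launched. The BackendError type below is a hypothetical stand-in for the crate's TensorRtLlmBackendError::ArgumentValidation, used only to keep the sketch self-contained.

use std::path::PathBuf;

/// Hypothetical stand-in for the backend's argument-validation error.
#[derive(Debug)]
enum BackendError {
    ArgumentValidation(String),
}

/// Reject a provided worker path that does not exist on disk.
fn validate_executor_worker(executor_worker: &Option<PathBuf>) -> Result<(), BackendError> {
    if let Some(worker) = executor_worker {
        if !worker.exists() {
            return Err(BackendError::ArgumentValidation(format!(
                "`executor_worker` specified path doesn't exist: {}",
                worker.display()
            )));
        }
    }
    Ok(())
}

fn main() {
    // No worker path provided: nothing to validate.
    assert!(validate_executor_worker(&None).is_ok());
    // A bogus path is rejected before the backend is constructed.
    assert!(validate_executor_worker(&Some(PathBuf::from("/does/not/exist"))).is_err());
}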