From 6687c06a21e76d40bdf14911b235116c269edf4d Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Fri, 18 Oct 2024 00:09:45 +0200
Subject: [PATCH] feat(looper): minor optimizations to avoid growing the
 containers too much

---
 backends/trtllm/src/looper.rs | 126 ++++++++++++++++++++--------------
 backends/trtllm/src/main.rs   |  15 ++--
 2 files changed, 82 insertions(+), 59 deletions(-)

diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs
index b6a18ca2..beae8e8e 100644
--- a/backends/trtllm/src/looper.rs
+++ b/backends/trtllm/src/looper.rs
@@ -79,11 +79,13 @@ struct DecodedTokenContext {
 
 fn executor_status_looper(
     mut backend: UniquePtr<TensorRtLlmBackendImpl>,
+    max_inflight_requests: usize,
     mut waiting_requests: UnboundedReceiver<GenerationContext>,
     post_processor_sender: UnboundedSender<(u64, InferResult<DecodedTokenContext>)>,
 ) {
     // Track the tuple (request_id, stream) for each request
-    let mut in_flights = HashMap::<u64, GenerationContext>::with_capacity(128);
+    let mut in_flights =
+        HashMap::<u64, GenerationContext>::with_capacity(max_inflight_requests * 2);
 
     // TODO: Does it need a spin-loop?
     'scheduler: loop {
@@ -169,9 +171,11 @@ fn executor_status_looper(
 
 fn post_processor_looper(
     tokenizer: Tokenizer,
+    max_num_tokens: usize,
+    max_inflight_requests: usize,
     mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
 ) {
-    let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(128);
+    let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(max_inflight_requests * 2);
 
     'post_processor: loop {
         if decoded_tokens.is_closed() {
@@ -182,11 +186,14 @@ fn post_processor_looper(
         if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
             match decoded {
                 Ok(ctx) => {
-                    states.entry(request_id).and_modify(|s| s.push(*&ctx.token.id)).or_insert_with(|| {
-                        let mut state = Vec::with_capacity(128);
-                        state.push(*&ctx.token.id);
-                        state
-                    });
+                    states
+                        .entry(request_id)
+                        .and_modify(|s| s.push(*&ctx.token.id))
+                        .or_insert_with(|| {
+                            let mut state = Vec::with_capacity(max_num_tokens);
+                            state.push(*&ctx.token.id);
+                            state
+                        });
 
                     let out = match tokenizer.decode(&[ctx.token.id], false) {
                         Ok(text) => {
@@ -232,12 +239,53 @@ fn post_processor_looper(
                     warn!("Failed to send decoded token back to the user")
                 }
             }
-            Err(err) => {}
+            Err(_err) => {
+                todo!("what do we do?")
+            }
         }
     }
 }
 
+fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
+    engine_folder: P,
+    executor_worker_path: PP,
+) -> Result<(String, String), TensorRtLlmBackendError> {
+    // Retrieve paths as &str for the backend creation
+    let engine_folder = engine_folder.as_ref();
+    let executor_worker_path = executor_worker_path.as_ref();
+
+    // Ensure the engine folder exists
+    if !engine_folder.exists() {
+        let err = TensorRtLlmBackendError::EngineFolderDoesntExists(engine_folder.to_path_buf());
+
+        error!("Path validation failed: {}", err);
+        return Err(err);
+    }
+
+    // Ensure executor worker binary exists
+    if !executor_worker_path.exists() {
+        let err = TensorRtLlmBackendError::ExecutorWorkerNotFound(executor_worker_path.to_path_buf());
+
+        error!("Path validation failed: {}", err);
+        return Err(err);
+    }
+
+    let engine_folder = String::from(
+        engine_folder
+            .to_str()
+            .expect("Failed to convert engine_folder to valid UTF-8"),
+    );
+
+    let executor_worker_path = String::from(
+        executor_worker_path
+            .to_str()
+            .expect("Failed to convert executor_worker_path to valid UTF-8"),
+    );
+
+    Ok((engine_folder, executor_worker_path))
+}
+
 unsafe impl Send for TensorRtLlmBackendImpl {}
 
 pub struct TensorRtLlmBackendV2 {
@@ -252,49 +300,10 @@ impl TensorRtLlmBackendV2 {
     pub fn new<P: AsRef<Path>, PP: AsRef<Path>>(
         tokenizer: Tokenizer,
         engine_folder: P,
         executor_worker_path: PP,
+        max_inflight_requests: usize,
     ) -> Result<Self, TensorRtLlmBackendError> {
-        // Retrieve paths as &str for the backend creation
-        let engine_folder = engine_folder.as_ref();
-        let executor_worker_path = executor_worker_path.as_ref();
-
-        // Ensure the engine folder exists
-        if !engine_folder.exists() {
-            let err =
-                TensorRtLlmBackendError::EngineFolderDoesntExists(engine_folder.to_path_buf());
-
-            error!(
-                err,
-                engine_folder = engine_folder.display(),
-                executor_worker_path = executor_worker_path.display()
-            );
-
-            return Err(err);
-        }
-
-        // Ensure executor worker binary exists
-        if !executor_worker_path.exists() {
-            let err = TensorRtLlmBackendError::ExecutorWorkerNotFound(engine_folder.to_path_buf());
-
-            error!(
-                err,
-                engine_folder = engine_folder.display(),
-                executor_worker_path = executor_worker_path.display()
-            );
-
-            return Err(err);
-        }
-
-        let engine_folder = String::from(
-            engine_folder
-                .to_str()
-                .expect("Failed to convert engine_folder to valid UTF-8"),
-        );
-
-        let executor_worker_path = String::from(
-            executor_worker_path
-                .to_str()
-                .expect("Failed to convert executor_worker_path to valid UTF-8"),
-        );
+        let (engine_folder, executor_worker_path) =
+            ensure_paths_exist(engine_folder, executor_worker_path)?;
 
         // Allocate the IPC layer to communicate with the backend
         let (executor_sender, executor_receiver) = unbounded_channel();
@@ -306,13 +315,24 @@ impl TensorRtLlmBackendV2 {
 
         // Executor looper is responsible for scheduling and pulling requests state at regular interval
        let executor_looper = spawn_blocking(move || {
-            executor_status_looper(backend, executor_receiver, post_processor_sender)
+            executor_status_looper(
+                backend,
+                max_inflight_requests,
+                executor_receiver,
+                post_processor_sender,
+            )
         });
 
         // Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
         let tokenizer_ = tokenizer.clone();
-        let post_processor_looper =
-            spawn_blocking(move || post_processor_looper(tokenizer_, post_processor_receiver));
+        let post_processor_looper = spawn_blocking(move || {
+            post_processor_looper(
+                tokenizer_,
+                512,
+                max_inflight_requests,
+                post_processor_receiver,
+            )
+        });
 
         Ok(TensorRtLlmBackendV2 {
             tokenizer,
diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs
index e78134b9..92712988 100644
--- a/backends/trtllm/src/main.rs
+++ b/backends/trtllm/src/main.rs
@@ -1,17 +1,15 @@
 use std::path::{Path, PathBuf};
 
 use clap::Parser;
-use hf_hub::{Cache, Repo, RepoType};
 use hf_hub::api::tokio::{Api, ApiBuilder};
+use hf_hub::{Cache, Repo, RepoType};
 use tokenizers::Tokenizer;
 use tracing::info;
 
 use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
 use text_generation_backends_trtllm::TensorRtLlmBackendV2;
-use text_generation_router::{HubTokenizerConfig, server};
-use text_generation_router::server::{
-    create_post_processor, get_base_tokenizer,
-};
+use text_generation_router::server::{create_post_processor, get_base_tokenizer};
+use text_generation_router::{server, HubTokenizerConfig};
 
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -282,7 +280,12 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         .expect("Failed to retrieve tokenizer implementation");
 
     info!("Successfully retrieved tokenizer {}", &tokenizer_name);
-    let backend = TensorRtLlmBackendV2::new(tokenizer, model_id, executor_worker)?;
+    let backend = TensorRtLlmBackendV2::new(
+        tokenizer,
+        model_id,
+        executor_worker,
+        max_concurrent_requests,
+    )?;
 
     info!("Successfully created backend");
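
Note on the sizing change (an illustrative sketch, not part of the applied
diff): both loopers now pre-size their maps to `max_inflight_requests * 2`
instead of a fixed 128, presumably so that no rehash ever occurs while the
scheduler keeps the number of in-flight requests at or under the configured
ceiling. The standalone snippet below demonstrates that property; the constant
and the assert are hypothetical stand-ins, not code from this backend:

    use std::collections::HashMap;

    fn main() {
        // Hypothetical ceiling, mirroring the new `max_inflight_requests` argument.
        let max_inflight_requests: usize = 128;

        // Pre-size to 2x the ceiling, as the loopers now do.
        let mut in_flights: HashMap<u64, Vec<u32>> =
            HashMap::with_capacity(max_inflight_requests * 2);

        let initial_capacity = in_flights.capacity();

        // Fill up to the ceiling: insertions stay below the reserved capacity.
        for request_id in 0..max_inflight_requests as u64 {
            in_flights.insert(request_id, Vec::new());
        }

        // The map never reallocated: its capacity is unchanged.
        assert_eq!(in_flights.capacity(), initial_capacity);
        println!("capacity held steady at {initial_capacity}");
    }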
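
For the shape of the pipeline the new argument is threaded through, here is a
minimal sketch under simplified assumptions: the backend runs two blocking
loopers wired together by unbounded tokio channels, and `max_inflight_requests`
now reaches both. The types below are hypothetical stand-ins for the real
`GenerationContext` and `DecodedTokenContext`, not the backend's actual API:

    use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
    use tokio::task::spawn_blocking;

    // Hypothetical stand-ins for the real request / decoded-token types.
    type Request = u64;
    type Decoded = (u64, String);

    // Plays the role of executor_status_looper: pull requests, emit decoded tokens.
    fn executor_loop(mut rx: UnboundedReceiver<Request>, tx: UnboundedSender<Decoded>) {
        while let Some(id) = rx.blocking_recv() {
            let _ = tx.send((id, format!("token-{id}")));
        }
    }

    // Plays the role of post_processor_looper: drain tokens until the channel closes.
    fn post_processor_loop(mut rx: UnboundedReceiver<Decoded>) {
        while let Some((id, text)) = rx.blocking_recv() {
            println!("request {id} -> {text}");
        }
    }

    #[tokio::main]
    async fn main() {
        let (req_tx, req_rx) = unbounded_channel();
        let (dec_tx, dec_rx) = unbounded_channel();

        let executor = spawn_blocking(move || executor_loop(req_rx, dec_tx));
        let post_processor = spawn_blocking(move || post_processor_loop(dec_rx));

        for id in 0..3 {
            req_tx.send(id).unwrap();
        }
        drop(req_tx); // closing the request side lets both loopers exit

        executor.await.unwrap();
        post_processor.await.unwrap();
    }

Dropping the request sender closes the chain: `blocking_recv()` returns `None`
in the executor loop, its decoded-token sender is dropped in turn, and the
post-processor exits, analogous to the shutdown path the real loopers take via
`is_closed()` and `blocking_recv()`.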