From 933ab67aa1322aaa2062d001e6b7d5cb1cf6c39a Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Mon, 5 Aug 2024 07:56:14 +0000
Subject: [PATCH] (ffi) encode the provided user prompt within each request
 thread

---
 backends/trtllm/src/lib.rs    |  1 +
 backends/trtllm/src/looper.rs | 60 ++++++++++++++++++++++++++++++-----
 backends/trtllm/src/utils.rs  | 22 +++++++++++++
 3 files changed, 75 insertions(+), 8 deletions(-)
 create mode 100644 backends/trtllm/src/utils.rs

diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs
index 00a510a7..e6e97c03 100644
--- a/backends/trtllm/src/lib.rs
+++ b/backends/trtllm/src/lib.rs
@@ -2,6 +2,7 @@ pub use looper::TensorRtLlmBackendV2;
 
 pub mod errors;
 mod looper;
+mod utils;
 
 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs
index 29866c2f..3db7d1ab 100644
--- a/backends/trtllm/src/looper.rs
+++ b/backends/trtllm/src/looper.rs
@@ -6,7 +6,7 @@ use std::sync::OnceLock;
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use hashbrown::HashMap;
-use tokenizers::Tokenizer;
+use tokenizers::{Encoding, Tokenizer};
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio::task::JoinHandle;
 use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -14,10 +14,12 @@ use tracing::{error, info, Level, span};
 
 use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
 use text_generation_router::infer::InferError::GenerationError;
-use text_generation_router::validation::ValidGenerateRequest;
+use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
+use text_generation_router::validation::ValidationError::UnsupportedModality;
 
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
+use crate::utils::first_line;
 
 // Value used to poll the state of the generation stream
 static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
@@ -27,6 +29,11 @@ unsafe impl Send for TensorRtLlmBackendImpl {}
 
 type InferResult<T> = Result<T, InferError>;
 
+struct ValidGenerateRequestWithTokens {
+    encoding: Encoding,
+    inner: ValidGenerateRequest,
+}
+
 fn executor_status_poller(
     mut backend: UniquePtr<TensorRtLlmBackendImpl>,
     mut waiting_requests: UnboundedReceiver<GenerationContext>,
@@ -47,12 +54,12 @@ fn executor_status_poller(
             // Submit all the request to the executor and move the context to the in-flight tracker
             for ctx in requests {
                 let request = &ctx.request;
-                let generation_params = &request.parameters;
-                let stopping_params = &request.stopping_parameters;
+                let generation_params = &request.inner.parameters;
+                let stopping_params = &request.inner.stopping_parameters;
 
                 // Submit to the TensorRT-LLM executor for scheduling
                 match backend.pin_mut().submit(
-                    &vec![],
+                    request.encoding.get_ids(),
                     stopping_params.max_new_tokens,
                     generation_params.top_k as i32,
                     generation_params.top_p,
@@ -110,7 +117,7 @@ fn executor_status_poller(
 }
 
 struct GenerationContext {
-    request: ValidGenerateRequest,
+    request: ValidGenerateRequestWithTokens,
     streamer: UnboundedSender<InferResult<InferStreamResponse>>,
 }
 
@@ -147,7 +154,7 @@ impl TensorRtLlmBackendV2 {
 
         // Create the FFI backend
         let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path)
-            .map_err(|e| TensorRtLlmBackendError::Runtime(e.what().to_string()))?;
+            .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
 
         // Looper is responsible for scheduling and pulling requests state at regular interval
         let looper =
@@ -159,15 +166,52 @@ impl TensorRtLlmBackendV2 {
             queue: requests_sender,
         })
     }
+
+    fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
+        if request.top_n_tokens > 1 {
+            return Err(InferError::ValidationError(
+                ValidationError::TopNTokensDisabled,
+            ));
+        }
+
+        // TODO: Is it really needed? How can it be validated before?
+        if request.parameters.grammar.is_some() {
+            return Err(InferError::ValidationError(ValidationError::Grammar));
+        }
+
+        match request.inputs.len() {
+            0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
+            2.. => Err(InferError::GenerationError(
+                "TensorRT-LLM backend doesn't support multi-chunk".into(),
+            )),
+            1 => match request.inputs.first().expect("Single item-chunk") {
+                Chunk::Text(text) => Ok(text),
+                Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
+            },
+        }
+    }
 }
 
 #[async_trait]
 impl Backend for TensorRtLlmBackendV2 {
     fn schedule(
         &self,
-        request: ValidGenerateRequest,
+        inner: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<InferResult<InferStreamResponse>>, InferError> {
+        let prompt = Self::validate(&inner)?;
+
+        // We encode the prompt in every request context/thread
+        let encoding = self
+            .tokenizer
+            .encode(prompt.as_str(), true)
+            .map_err(|e| GenerationError(format!("Tokenization failed {}", e.to_string())))?;
+
+        let request = ValidGenerateRequestWithTokens { encoding, inner };
+
+        // Open up the stream to send tokens
         let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
+
+        // Send the context to the executor for scheduling
         match self.queue.send(GenerationContext { request, streamer }) {
             Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
             Err(_) => Err(GenerationError(
diff --git a/backends/trtllm/src/utils.rs b/backends/trtllm/src/utils.rs
new file mode 100644
index 00000000..4dedb007
--- /dev/null
+++ b/backends/trtllm/src/utils.rs
@@ -0,0 +1,22 @@
+///
+/// Extract the first line of the provided string reference.
+/// If there are no lines in the buffer, it returns a string
+/// whose content is defined by `fail`.
+/// # Arguments
+///
+/// * `s`: The string buffer to extract the first line from
+/// * `fail`: The string content returned if no lines are
+///   present in `s`
+///
+/// returns: String
+///
+/// # Examples
+///
+/// ```
+/// let s = "My name is Morgan.\n I'm working at Hugging Face.";
+/// first_line(s, "No line in string");
+/// ```
+#[inline]
+pub(crate) fn first_line(s: &str, fail: &str) -> String {
+    s.lines().next().unwrap_or(fail).to_string()
+}
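
Below is a minimal standalone sketch (not part of the commit) of the behaviour the patch relies on from the new `first_line` helper in `backends/trtllm/src/utils.rs`: the function body is copied from the patch, while the example error messages are made up for illustration.

```rust
// Copied from backends/trtllm/src/utils.rs for illustration only.
fn first_line(s: &str, fail: &str) -> String {
    s.lines().next().unwrap_or(fail).to_string()
}

fn main() {
    // A multi-line C++ exception message (e.what()) is reduced to its first line,
    // which is what TensorRtLlmBackendError::Runtime now receives.
    let what = "engine folder does not exist\nbacktrace:\n#0 ...";
    assert_eq!(first_line(what, "Unknown error"), "engine folder does not exist");

    // An empty message falls back to the provided default.
    assert_eq!(first_line("", "Unknown error"), "Unknown error");
}
```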