(ffi) encode the provided user prompt within each request thread

Morgan Funtowicz 2024-08-05 07:56:14 +00:00 committed by Morgan Funtowicz
parent 0b0c30fe8b
commit 933ab67aa1
3 changed files with 75 additions and 8 deletions

View File: src/lib.rs

@@ -2,6 +2,7 @@ pub use looper::TensorRtLlmBackendV2
pub mod errors;
mod looper;
+mod utils;

#[cxx::bridge(namespace = "huggingface::tgi::backends")]
mod ffi {
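The body of the `ffi` module is cut off by the diff view. For context, a cxx bridge of this shape typically declares the C++ type and constructor consumed by the Rust side; the sketch below is illustrative only, with a hypothetical header path, while `TensorRtLlmBackendImpl` and `create_tensorrt_llm_backend` are confirmed by the imports in the looper:

```rust
// Illustrative cxx bridge sketch, not the module's actual contents.
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
mod ffi {
    unsafe extern "C++" {
        // Hypothetical header path for the C++ side of the bridge.
        include!("backends/trtllm/include/ffi.h");

        type TensorRtLlmBackendImpl;

        // Returning Result<...> lets cxx translate C++ exceptions into
        // cxx::Exception, whose what() message the Rust caller consumes.
        fn create_tensorrt_llm_backend(
            engine_folder: &str,
            executor_worker: &str,
        ) -> Result<UniquePtr<TensorRtLlmBackendImpl>>;
    }
}
```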

View File: src/looper.rs

@@ -6,7 +6,7 @@ use std::sync::OnceLock;
use async_trait::async_trait;
use cxx::UniquePtr;
use hashbrown::HashMap;
-use tokenizers::Tokenizer;
+use tokenizers::{Encoding, Tokenizer};
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::task::JoinHandle;
use tokio_stream::wrappers::UnboundedReceiverStream;
@@ -14,10 +14,12 @@ use tracing::{error, info, Level, span};
use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
use text_generation_router::infer::InferError::GenerationError;
-use text_generation_router::validation::ValidGenerateRequest;
+use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
+use text_generation_router::validation::ValidationError::UnsupportedModality;

use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
+use crate::utils::first_line;
// Value used to poll the state of the generation stream
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
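The `POLLING_INTERVAL_US` static uses `OnceLock` so the interval is resolved once and then read cheaply on every iteration of the polling loop. A minimal sketch of that pattern, where the environment variable name and default value are assumptions rather than taken from this diff:

```rust
use std::sync::OnceLock;

static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();

fn polling_interval_us() -> u64 {
    // get_or_init runs the closure exactly once; later calls are plain reads.
    *POLLING_INTERVAL_US.get_or_init(|| {
        std::env::var("TRTLLM_EXECUTOR_POLLING_INTERVAL_US") // hypothetical variable name
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(100)
    })
}

fn main() {
    println!("polling every {}us", polling_interval_us());
}
```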
@@ -27,6 +29,11 @@ unsafe impl Send for TensorRtLlmBackendImpl {}
type InferResult<T> = Result<T, InferError>;

+struct ValidGenerateRequestWithTokens {
+    encoding: Encoding,
+    inner: ValidGenerateRequest,
+}
+
fn executor_status_poller(
    mut backend: UniquePtr<TensorRtLlmBackendImpl>,
    mut waiting_requests: UnboundedReceiver<GenerationContext>,
@@ -47,12 +54,12 @@ fn executor_status_poller(
        // Submit all the requests to the executor and move the context to the in-flight tracker
        for ctx in requests {
            let request = &ctx.request;
-            let generation_params = &request.parameters;
-            let stopping_params = &request.stopping_parameters;
+            let generation_params = &request.inner.parameters;
+            let stopping_params = &request.inner.stopping_parameters;

            // Submit to the TensorRT-LLM executor for scheduling
            match backend.pin_mut().submit(
-                &vec![],
+                request.encoding.get_ids(),
                stopping_params.max_new_tokens,
                generation_params.top_k as i32,
                generation_params.top_p,
@@ -110,7 +117,7 @@ fn executor_status_poller(
}

struct GenerationContext {
-    request: ValidGenerateRequest,
+    request: ValidGenerateRequestWithTokens,
    streamer: UnboundedSender<InferResult<InferStreamResponse>>,
}
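`GenerationContext` carries the sender half of an unbounded channel, while `schedule` (below) hands the receiver back to the caller as a stream. A self-contained sketch of that handshake, with the item type simplified to `String` for illustration:

```rust
use tokio::sync::mpsc::unbounded_channel;
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};

#[tokio::main]
async fn main() {
    // schedule() side: create the channel and wrap the receiver as a stream.
    let (streamer, receiver) = unbounded_channel::<Result<String, String>>();
    let mut stream = UnboundedReceiverStream::new(receiver);

    // Looper side: push intermediate generation steps as they arrive.
    streamer.send(Ok("token".into())).unwrap();
    drop(streamer); // dropping the last sender terminates the stream

    while let Some(item) = stream.next().await {
        println!("{item:?}");
    }
}
```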
@@ -147,7 +154,7 @@ impl TensorRtLlmBackendV2 {
        // Create the FFI backend
        let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path)
-            .map_err(|e| TensorRtLlmBackendError::Runtime(e.what().to_string()))?;
+            .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;

        // The looper is responsible for scheduling and pulling request state at a regular interval
        let looper =
@@ -159,15 +166,52 @@ impl TensorRtLlmBackendV2 {
            queue: requests_sender,
        })
    }
+
+    fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
+        if request.top_n_tokens > 1 {
+            return Err(InferError::ValidationError(
+                ValidationError::TopNTokensDisabled,
+            ));
+        }
+
+        // TODO: Is this really needed? How can it be validated beforehand?
+        if request.parameters.grammar.is_some() {
+            return Err(InferError::ValidationError(ValidationError::Grammar));
+        }
+
+        match request.inputs.len() {
+            0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
+            2.. => Err(InferError::GenerationError(
+                "TensorRT-LLM backend doesn't support multi-chunk inputs".into(),
+            )),
+            1 => match request.inputs.first().expect("Single item-chunk") {
+                Chunk::Text(text) => Ok(text),
+                Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
+            },
+        }
+    }
}
#[async_trait]
impl Backend for TensorRtLlmBackendV2 {
    fn schedule(
        &self,
-        request: ValidGenerateRequest,
+        inner: ValidGenerateRequest,
    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
+        let prompt = Self::validate(&inner)?;
+
+        // We encode the prompt in every request context/thread
+        let encoding = self
+            .tokenizer
+            .encode(prompt.as_str(), true)
+            .map_err(|e| GenerationError(format!("Tokenization failed: {e}")))?;
+
+        let request = ValidGenerateRequestWithTokens { encoding, inner };

        // Open up the stream used to send tokens back
        let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();

        // Send the context to the executor for scheduling
        match self.queue.send(GenerationContext { request, streamer }) {
            Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
            Err(_) => Err(GenerationError(

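The key behavioral change above is that tokenization now happens inside `schedule`, once per request, rather than the executor receiving raw text. A minimal sketch of that encoding step with the `tokenizers` crate, where the tokenizer file name is an assumption:

```rust
use tokenizers::Tokenizer;

fn main() -> Result<(), tokenizers::Error> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    // `true` adds special tokens, matching the call in schedule().
    let encoding = tokenizer.encode("Hello, world!", true)?;
    // These ids are what request.encoding.get_ids() hands to submit().
    println!("{:?}", encoding.get_ids());
    Ok(())
}
```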
View File: src/utils.rs

@@ -0,0 +1,22 @@
+///
+/// Extracts the first line of the provided string reference.
+/// If there are no lines in the buffer, it returns a string
+/// whose content is defined by the content of `fail`
+/// # Arguments
+///
+/// * `s`: The string buffer to extract the first line from
+/// * `fail`: A string whose content is returned if no lines are
+///   present in `s`
+///
+/// returns: String
+///
+/// # Examples
+///
+/// ```
+/// let s = "My name is Morgan.\n I'm working at Hugging Face.";
+/// first_line(s, "No line in string");
+/// ```
+#[inline]
+pub(crate) fn first_line(s: &str, fail: &str) -> String {
+    s.lines().next().unwrap_or(fail).to_string()
+}
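A quick, self-contained check of `first_line`'s behavior; the function is reproduced here since it is `pub(crate)`:

```rust
fn first_line(s: &str, fail: &str) -> String {
    s.lines().next().unwrap_or(fail).to_string()
}

fn main() {
    assert_eq!(
        first_line("My name is Morgan.\nI work at Hugging Face.", "n/a"),
        "My name is Morgan."
    );
    // An empty buffer yields no lines, so the fallback is returned.
    assert_eq!(first_line("", "n/a"), "n/a");
}
```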