(ffi) encode the provided user prompt within each request thread
parent 0b0c30fe8b · commit 933ab67aa1
```diff
@@ -2,6 +2,7 @@ pub use looper::TensorRtLlmBackendV2;
 pub mod errors;
 mod looper;
+mod utils;
 
 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
```
```diff
@@ -6,7 +6,7 @@ use std::sync::OnceLock;
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use hashbrown::HashMap;
-use tokenizers::Tokenizer;
+use tokenizers::{Encoding, Tokenizer};
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio::task::JoinHandle;
 use tokio_stream::wrappers::UnboundedReceiverStream;
```
```diff
@@ -14,10 +14,12 @@ use tracing::{error, info, Level, span};
 
 use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
 use text_generation_router::infer::InferError::GenerationError;
-use text_generation_router::validation::ValidGenerateRequest;
+use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
+use text_generation_router::validation::ValidationError::UnsupportedModality;
 
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
+use crate::utils::first_line;
 
 // Value used to poll the state of the generation stream
 static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
```
```diff
@@ -27,6 +29,11 @@ unsafe impl Send for TensorRtLlmBackendImpl {}
 
 type InferResult<T> = Result<T, InferError>;
 
+struct ValidGenerateRequestWithTokens {
+    encoding: Encoding,
+    inner: ValidGenerateRequest,
+}
+
 fn executor_status_poller(
     mut backend: UniquePtr<TensorRtLlmBackendImpl>,
     mut waiting_requests: UnboundedReceiver<GenerationContext>,
```
```diff
@@ -47,12 +54,12 @@ fn executor_status_poller(
             // Submit all the requests to the executor and move the context to the in-flight tracker
             for ctx in requests {
                 let request = &ctx.request;
-                let generation_params = &request.parameters;
-                let stopping_params = &request.stopping_parameters;
+                let generation_params = &request.inner.parameters;
+                let stopping_params = &request.inner.stopping_parameters;
 
                 // Submit to the TensorRT-LLM executor for scheduling
                 match backend.pin_mut().submit(
-                    &vec![],
+                    request.encoding.get_ids(),
                     stopping_params.max_new_tokens,
                     generation_params.top_k as i32,
                     generation_params.top_p,
```
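This is the heart of the change: `submit` previously received an empty placeholder vector and now receives the token ids pre-computed in the request thread, via `request.encoding.get_ids()`. For context, a minimal self-contained sketch of the `tokenizers` calls involved; the `tokenizer.json` path is a placeholder:

```rust
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder path: any Hugging Face tokenizer.json works here.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // `true` adds special tokens, matching the `encode(prompt.as_str(), true)`
    // call introduced in `schedule` below.
    let encoding = tokenizer.encode("Hello, world!", true)?;

    // `get_ids` exposes the &[u32] slice handed to the executor's `submit`.
    println!("token ids: {:?}", encoding.get_ids());
    Ok(())
}
```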
```diff
@@ -110,7 +117,7 @@ fn executor_status_poller(
 }
 
 struct GenerationContext {
-    request: ValidGenerateRequest,
+    request: ValidGenerateRequestWithTokens,
     streamer: UnboundedSender<InferResult<InferStreamResponse>>,
 }
```
```diff
@@ -147,7 +154,7 @@ impl TensorRtLlmBackendV2 {
 
         // Create the FFI backend
         let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path)
-            .map_err(|e| TensorRtLlmBackendError::Runtime(e.what().to_string()))?;
+            .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
 
         // Looper is responsible for scheduling and pulling requests state at regular interval
         let looper =
```
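`e.what()` on a C++ exception crossing the cxx bridge often carries a multi-line payload (message followed by a backtrace); `first_line` trims it to the leading line so the surfaced `Runtime` error stays terse. A small illustration of the intended behavior, using `first_line` from the new `utils` module shown at the end of this diff (the error text is invented):

```rust
// Hypothetical multi-line what() output from the executor.
let what = "Failed to load the engine\nBacktrace:\n#0 executor.cpp:42";
assert_eq!(first_line(what, "Unknown error"), "Failed to load the engine");

// An empty buffer falls back to the provided default.
assert_eq!(first_line("", "Unknown error"), "Unknown error");
```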
```diff
@@ -159,15 +166,52 @@ impl TensorRtLlmBackendV2 {
             queue: requests_sender,
         })
     }
+
+    fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
+        if request.top_n_tokens > 1 {
+            return Err(InferError::ValidationError(
+                ValidationError::TopNTokensDisabled,
+            ));
+        }
+
+        // TODO: Is it really needed? How can it be validated before?
+        if request.parameters.grammar.is_some() {
+            return Err(InferError::ValidationError(ValidationError::Grammar));
+        }
+
+        match request.inputs.len() {
+            0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
+            2.. => Err(InferError::GenerationError(
+                "TensorRT-LLM backend doesn't support multi-chunk".into(),
+            )),
+            1 => match request.inputs.first().expect("Single item-chunk") {
+                Chunk::Text(text) => Ok(text),
+                Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
+            },
+        }
+    }
 }
 
 #[async_trait]
 impl Backend for TensorRtLlmBackendV2 {
     fn schedule(
         &self,
-        request: ValidGenerateRequest,
+        inner: ValidGenerateRequest,
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
+        let prompt = Self::validate(&inner)?;
+
+        // We encode the prompt in every request context/thread
+        let encoding = self
+            .tokenizer
+            .encode(prompt.as_str(), true)
+            .map_err(|e| GenerationError(format!("Tokenization failed: {}", e)))?;
+
+        let request = ValidGenerateRequestWithTokens { encoding, inner };
+
         // Open-up the stream to send tokens
         let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
 
         // Send the context to the executor for scheduling
         match self.queue.send(GenerationContext { request, streamer }) {
             Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
             Err(_) => Err(GenerationError(
```
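Putting the pieces together: `schedule` now validates, tokenizes in the caller's own context, wraps the result in `ValidGenerateRequestWithTokens`, and hands it to the looper over the unbounded queue. A stripped-down, runnable sketch of that hand-off pattern; `PreTokenized` is a hypothetical stand-in for the real request type:

```rust
use tokio::sync::mpsc::unbounded_channel;

// Hypothetical stand-in for ValidGenerateRequestWithTokens: by the time
// the context reaches the consumer, the tokenization cost is already paid.
struct PreTokenized {
    ids: Vec<u32>,
}

#[tokio::main]
async fn main() {
    let (queue, mut requests) = unbounded_channel::<PreTokenized>();

    // Producer side (mirrors `schedule`): encode in the request thread,
    // then enqueue the ready-to-submit context.
    queue
        .send(PreTokenized { ids: vec![101, 2023, 102] })
        .expect("receiver alive");

    // Consumer side (mirrors `executor_status_poller`): drain whatever is
    // pending and submit it to the executor.
    while let Ok(ctx) = requests.try_recv() {
        println!("submitting {} tokens to the executor", ctx.ids.len());
    }
}
```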
```diff
@@ -0,0 +1,22 @@
+///
+/// Extract the first line of the provided string reference.
+/// If there are no lines in the buffer, it returns a string
+/// whose content is defined by `fail`.
+/// # Arguments
+///
+/// * `s`: The string buffer to extract the first line from
+/// * `fail`: A string which is returned if no lines are
+///   present in `s`
+///
+/// returns: String
+///
+/// # Examples
+///
+/// ```
+/// let s = "My name is Morgan.\n I'm working at Hugging Face.";
+/// first_line(s, "No line in string");
+/// ```
+#[inline]
+pub(crate) fn first_line(s: &str, fail: &str) -> String {
+    s.lines().next().unwrap_or(fail).to_string()
+}
```
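For completeness, a small test module one could place alongside the helper (hypothetical, not part of this commit):

```rust
#[cfg(test)]
mod tests {
    use super::first_line;

    #[test]
    fn picks_the_first_line() {
        assert_eq!(first_line("a\nb\nc", "fallback"), "a");
    }

    #[test]
    fn falls_back_on_empty_input() {
        assert_eq!(first_line("", "fallback"), "fallback");
    }
}
```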