impl RwLock scenario for TensorRtLlmBackend

Morgan Funtowicz 2024-07-16 20:08:10 +00:00
parent 31d9f4d5dc
commit 7784a21d48
8 changed files with 352 additions and 173 deletions

View File

@ -15,6 +15,10 @@ tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot"
tokio-stream = "0.1.14"
clap = { version = "4.5.4", features = ["derive"] }
thiserror = "1.0.61"
tracing = "0.1"
tracing-opentelemetry = "0.24"
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
log = { version = "0.4.21", features = [] }
[build-dependencies]
cmake = "0.1"

View File

@ -33,6 +33,8 @@ fn main() {
"debug" => format!("{}d", dependency),
_ => String::from(dependency),
};
let dep_path = deps_folder.join(format!("{}-build", dependency));
println!("cargo:rustc-link-search={}", dep_path.display());
println!("cargo:rustc-link-lib=static={}", dep_name);
}
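
For context, a minimal sketch from the editor (not part of the commit; `deps_folder`, `dependencies`, and `profile` are assumed names reconstructed from the hunk above) of how such a build-script loop might look: each native dependency built by CMake gets a per-dependency "<name>-build" search path and a static-link directive, with a "d" suffix on the library name for debug profiles.

use std::path::Path;

// build.rs sketch -- emits one search path and one static-link directive per dependency.
fn emit_link_directives(deps_folder: &Path, dependencies: &[&str], profile: &str) {
    for &dependency in dependencies {
        // Debug builds of the native libraries are suffixed with "d"
        let dep_name = match profile {
            "debug" => format!("{}d", dependency),
            _ => String::from(dependency),
        };
        let dep_path = deps_folder.join(format!("{}-build", dependency));
        println!("cargo:rustc-link-search={}", dep_path.display());
        println!("cargo:rustc-link-lib=static={}", dep_name);
    }
}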

View File

@ -17,14 +17,11 @@ else ()
set(FAST_BUILD OFF)
endif ()
# This line turns off DEBUG in the TRTLLM logger, which is quite spammy
add_compile_definitions(NDEBUG OFF)
fetchcontent_declare(
trtllm
GIT_REPOSITORY https://github.com/nvidia/tensorrt-llm.git
GIT_TAG a96cccafcf6365c128f004f779160951f8c0801c
GIT_SHALLOW TRUE
GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
GIT_TAG 9691e12bce7ae1c126c435a049eb516eb119486c
GIT_SHALLOW FALSE
)
fetchcontent_makeavailable(trtllm)
message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")

View File

@ -5,7 +5,7 @@
#ifndef TGI_TRTLLM_BACKEND_FFI_H
#define TGI_TRTLLM_BACKEND_FFI_H
//#include "rust/cxx.h"
#include <cstddef>
#include "backend.h"
namespace huggingface::tgi::backends {
@ -17,9 +17,9 @@ namespace huggingface::tgi::backends {
namespace huggingface::tgi::backends {
struct GenerationContext;
// struct GenerationContext;
class TensorRtLlmBackendImpl : TensorRtLlmBackend {
class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
public:
/***
*
@ -37,7 +37,6 @@ namespace huggingface::tgi::backends {
/***
*
* @param tokens
* @param maxNewTokens
* @param topK
* @param topP
* @param temperature
@ -45,17 +44,20 @@ namespace huggingface::tgi::backends {
* @return
*/
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
uint64_t Submit(rust::Slice<const uint32_t> tokens, int32_t maxNewTokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed);
uint64_t
Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed);
/***
*
* @param requestId
* @param handler
* @param ctx
* @param callback
* @return
*/
uint32_t Stream(rust::Box <GenerationContext> ctx,
uint64_t requestId,
rust::Fn<void(rust::Box<GenerationContext>, uint32_t, uint32_t, bool)> handler);
size_t StreamTokens(
const RequestId requestId,
rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, float_t, bool)> callback);
};
/***

View File

@ -1,160 +1,311 @@
use std::cell::RefCell;
use std::future::Future;
use std::path::Path;
use std::pin::{pin, Pin};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;
use async_trait::async_trait;
use cxx::UniquePtr;
use log::{info, warn};
use tokenizers::Tokenizer;
use tokio::sync::mpsc;
use tokio::time::Instant;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::sync::RwLock;
use tokio::time::{Instant, sleep};
use tokio_stream::{Stream, StreamExt};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{instrument, Level, span};
use text_generation_router::{FinishReason, Token};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidParameters};
use text_generation_router::validation::ValidGenerateRequest;
use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
// macro_rules! propagate {
// ($ctx: expr, $res: expr) => {
// $ctx.sender
// .send($res)
// .expect("Failed to propagate error back to the transport layer")
// };
// }
type InferResult<T> = Result<T, InferError>;
pub struct GenerationContext(mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>);
/// Holds the user provided input to be executed, along with a channel used
/// to bubble up all the generated tokens for that request to the end stream.
// pub struct InferenceContext {
// /// User provided request
// request: ValidGenerateRequest,
//
// /// Inter-process communication handler moving token from the executor thread to the HTTP server
// sender: UnboundedSender<InferResult<InferStreamResponse>>,
//
// /// Pin the instant this inference context was submitted
// when: Instant,
//
// /// Span that will live as long as entry
// span: Span,
// }
pub struct TrtLLmBackend {
tokenizer: Tokenizer,
inner: RefCell<UniquePtr<TensorRtLlmBackendImpl>>,
pub(crate) struct Generation {
executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
done: Arc<AtomicBool>,
}
unsafe impl Sync for TrtLLmBackend {}
unsafe impl Send for TrtLLmBackend {}
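// Editor's note: the tuple below pairs the per-request response sender with the shared
// `done` flag; the FFI callback flips the flag on the final token so the `Generation`
// stream defined above knows to terminate.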
pub struct GenerationContext(
UnboundedSender<InferResult<InferStreamResponse>>,
Arc<AtomicBool>,
);
impl TrtLLmBackend {
pub fn new<P: AsRef<Path>>(
tokenizer: Tokenizer,
engine_folder: P,
) -> Result<Self, TensorRtLlmBackendError> {
let engine_folder = engine_folder.as_ref();
let inner = create_tensorrt_llm_backend(engine_folder.to_str().unwrap(), "");
impl Stream for Generation {
type Item = usize;
Ok(Self {
tokenizer,
inner: RefCell::new(inner),
})
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.done.load(Ordering::Relaxed) {
Poll::Ready(None)
} else {
let pinned = pin!(self.executor.read());
match pinned.poll(ctx) {
Poll::Ready(executor_r) => {
let ready = executor_r.num_responses_ready();
if ready == 0 {
let waker = ctx.waker().clone();
tokio::spawn(async {
sleep(Duration::from_millis(10)).await;
waker.wake();
});
Poll::Pending
} else {
info!("Ready: {}", ready);
let waker = ctx.waker().clone();
tokio::spawn(async {
sleep(Duration::from_millis(100)).await;
waker.wake();
});
Poll::Ready(Some(ready))
}
}
Poll::Pending => {
let waker = ctx.waker().clone();
tokio::spawn(async {
sleep(Duration::from_millis(100)).await;
waker.wake();
});
Poll::Pending
}
}
}
}
fn infer_text(
&self,
ctx: GenerationContext,
text: &str,
params: ValidParameters,
) -> InferResult<()> {
// Keep track of processing time
let start = Instant::now();
fn size_hint(&self) -> (usize, Option<usize>) {
(1, None)
}
}
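
An aside from the editor: the poll_next above re-arms its own waker by spawning a short sleep whenever the executor has nothing ready, instead of registering with an external notifier. A minimal, self-contained sketch of that wake-up pattern (toy `Ticker` type and 10 ms delay are assumptions, not part of the commit):

use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;

use tokio::time::sleep;
use tokio_stream::{Stream, StreamExt};

struct Ticker {
    remaining: usize,
    ready: bool,
}

impl Stream for Ticker {
    type Item = usize;

    fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if self.remaining == 0 {
            return Poll::Ready(None);
        }
        if self.ready {
            self.ready = false;
            self.remaining -= 1;
            Poll::Ready(Some(self.remaining))
        } else {
            // Nothing ready yet: clone the waker and wake it from a spawned task after a
            // short delay, the same trick Generation::poll_next uses around the executor lock.
            self.ready = true;
            let waker = ctx.waker().clone();
            tokio::spawn(async move {
                sleep(Duration::from_millis(10)).await;
                waker.wake();
            });
            Poll::Pending
        }
    }
}

#[tokio::main]
async fn main() {
    let mut ticker = Ticker { remaining: 3, ready: false };
    while let Some(left) = ticker.next().await {
        println!("tick, {} remaining", left);
    }
}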
// Encode the input
let ctx = Box::new(ctx);
let encoding = self
.tokenizer
.encode(text, true)
.map_err(|e| InferError::ToolError(e.to_string()))?;
unsafe impl Send for TensorRtLlmBackendImpl {}
unsafe impl Sync for TensorRtLlmBackendImpl {}
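// Editor's note (assumption, not stated in the commit): the opaque C++ type behind
// `TensorRtLlmBackendImpl` is not automatically Send/Sync, so these unsafe impls assert
// that it may be shared across Tokio worker threads, relying on the Arc<RwLock<...>>
// wrapper to serialize mutable access.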
// Submit the request to the backend and retrieve the handle to query its status
let request_id = self
.inner
.borrow_mut()
.as_mut()
.expect("Failed to retrieve pointer to TRTLLM backend")
.submit(
encoding.get_ids(),
128,
params.top_k as i32,
params.top_p,
params.temperature,
params.seed,
);
/// Implements the logic to execute generation with the TensorRT-LLM executor API in the background
pub struct TensorRtLlmBackend {
// Allows sending user requests to the TensorRT-LLM executor thread
// batcher: UnboundedSender<InferenceContext>,
backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
}
// Stream generated tokens
// spawn_blocking(move || {
let num_generated_tokens = self
.inner
.borrow_mut()
.as_mut()
.expect("Failed to retrieve pointer to TRTLLM backend")
.stream(ctx, request_id, |ctx, token, step, is_final| {
// self.tokenizer.decode(&*[token], true).unwrap();
let sender = ctx.0;
let token = Token {
id: token,
text: String::from(""),
logprob: 1.0f32,
special: false,
};
sender
.send(Ok(InferStreamResponse::Intermediate {
token,
top_tokens: vec![],
}))
.unwrap()
});
// Notify the end
let _ = ctx.0.send(Ok(InferStreamResponse::End {
token: Token {
id: 0,
text: String::from(""),
logprob: 1.0f32,
special: false,
},
top_tokens: vec![],
generated_text: GeneratedText {
text: String::from(""),
generated_tokens: num_generated_tokens,
finish_reason: FinishReason::EndOfSequenceToken,
seed: Some(params.seed),
},
start,
queued: Instant::now(),
}));
// });
Ok(())
impl TensorRtLlmBackend {
pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
_tokenizer: Tokenizer,
engine_folder: P,
_executor_worker_path: Option<PP>,
) -> Result<Self, TensorRtLlmBackendError> {
Ok(TensorRtLlmBackend {
backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
engine_folder.as_ref().to_str().unwrap(),
"",
))),
})
}
}
#[async_trait]
impl Backend for TrtLLmBackend {
impl Backend for TensorRtLlmBackend {
#[instrument(skip_all)]
fn schedule(
&self,
request: ValidGenerateRequest,
_request: ValidGenerateRequest,
) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
let (sender, receiver) = mpsc::unbounded_channel();
let ctx = GenerationContext(sender);
// Channel to stream the generated tokens as they come from the worker thread back to the transport layer
let (sender, receiver) = unbounded_channel();
// Unpack parameters
let params = request.parameters;
let executor = self.backend.clone();
tokio::spawn(async move {
// Submit the request to the batcher
let request_id = span!(Level::DEBUG, "[EXECUTOR][SUBMIT]")
.in_scope(|| async {
info!("Acquiring lock for submit");
let mut handle = executor.write().await;
let request_id = handle.pin_mut().submit(
&vec![2, 2926, 1503, 603, 20189],
50,
1.0,
1.0,
2014,
);
// Ensure we are running in the right conditions for the input (i.e. single textual chunk)
let input = match request.inputs.len() {
0 => Err(InferError::GenerationError("No input provided".into())),
1 => Ok(request.inputs.first().unwrap()),
_ => Err(InferError::GenerationError(format!(
"Unsupported multi-chunks ({}) inference.",
request.inputs.len()
))),
}?;
info!("Releasing lock for submit");
return request_id;
})
.await;
// Currently we handle single chunk of text
match input {
Chunk::Text(text) => {
self.infer_text(ctx, &**text, params)?;
let mut generation = Generation {
executor: executor.clone(),
done: Arc::new(AtomicBool::new(false)),
};
while let Some(num_tokens_ready) = generation.next().await {
span!(
Level::DEBUG,
"[EXECUTOR][GENERATE]",
request_id = request_id,
num_tokens_ready = num_tokens_ready
)
.in_scope(|| async {
let ctx = Box::new(GenerationContext(
sender.clone(),
Arc::clone(&generation.done),
));
let mut executor_w = executor.write().await;
info!("Acquired write lock stream");
executor_w.pin_mut().stream_tokens(
request_id,
ctx,
|ctx: Box<GenerationContext>, token: u32, logprob: f32, is_final: bool| {
info!("Sending token: {} (final: {})", token, is_final);
let out = if is_final {
ctx.1.store(true, Ordering::Relaxed);
InferStreamResponse::End {
token: Token {
id: token,
text: "".into(),
logprob,
special: false,
},
top_tokens: vec![],
generated_text: GeneratedText {
text: "".into(),
generated_tokens: u32::MAX,
finish_reason: FinishReason::EndOfSequenceToken,
seed: None,
},
start: Instant::now(),
queued: Instant::now(),
}
} else {
InferStreamResponse::Intermediate {
token: Token {
id: token,
text: "".into(),
logprob,
special: false,
},
top_tokens: vec![],
}
};
ctx.0
.send(Ok(out))
.expect("Failed to send back generated token");
},
);
info!("Releasing write lock stream")
})
.await;
}
Chunk::Image(_) => panic!("Unsupported"),
};
});
Ok(UnboundedReceiverStream::new(receiver))
}
async fn health(&self, _current_health: bool) -> bool {
self.inner.borrow_mut().is_ready()
true
}
}
// async fn background_looper<P: AsRef<Path>, PP: AsRef<Path>>(
// engine_folder: P,
// _executor_worker: Option<PP>,
// tokenizer: Tokenizer,
// mut receiver: UnboundedReceiver<InferenceContext>,
// ) {
// let mut backend = create_tensorrt_llm_backend(engine_folder.as_ref().to_str().unwrap(), "");
//
// while !(receiver.is_closed()) {
// // Receive the incoming request
// if let Some(ctx) = receiver.recv().await {
// debug!("Processing new incoming request");
//
// // We only support single, textual chunk
// if ctx.request.inputs.len() != 1 {
// propagate!(
// ctx,
// Err(InferError::GenerationError(format!(
// "Unsupported multi-chunk ({}) input",
// ctx.request.inputs.len()
// )))
// );
// }
//
// let input = ctx
// .request
// .inputs
// .first()
// .expect("Single chunk checked above");
// let params = ctx.request.parameters;
// }
// }
// Receive the incoming request
// if let Some(ctx) = receiver.recv().await {
// debug!("Processing new incoming request");
// // We only support single, textual chunk
// if ctx.request.inputs.len() != 1 {
// propagate!(
// ctx,
// Err(InferError::GenerationError(format!(
// "Unsupported multi-chunk ({}) input",
// ctx.request.inputs.len()
// )))
// );
// }
//
// // Unpack parameters
// let inputs = ctx.request.inputs;
// let params = ctx.request.parameters;
//
// match inputs.first().unwrap() {
// Chunk::Text(text) => match tokenizer.encode(text.as_str(), true) {
// Err(err) => {
// propagate!(ctx, Err(InferError::GenerationError(err.to_string())))
// }
// Ok(encoding) => {
// // spawn_blocking(|| {
// // info!("Submitting request to TensorRT-LLM executor");
// // let mut executor = backend.blocking_write();
// // })
// // .await
// // .expect("");
// }
// },
// Chunk::Image(_) => propagate!(
// ctx,
// Err(InferError::GenerationError(
// "Image input is not supported yet.".into(),
// ))
// ),
// }
// };
// }
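
To close this file out, a minimal, runnable sketch from the editor of the channel-to-stream plumbing that schedule() builds on: the producer side (standing in for the stream_tokens callback) pushes into an UnboundedSender, and the transport side drains the matching UnboundedReceiverStream. Result<u32, String> stands in for InferResult<InferStreamResponse>; the token ids are the hard-coded ones used in schedule() above.

use tokio::sync::mpsc::unbounded_channel;
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};

#[tokio::main]
async fn main() {
    let (sender, receiver) = unbounded_channel::<Result<u32, String>>();

    // Producer: stands in for the FFI callback handed to stream_tokens.
    tokio::spawn(async move {
        for token_id in [2u32, 2926, 1503, 603, 20189] {
            sender.send(Ok(token_id)).expect("receiver dropped");
        }
        // Dropping the sender closes the stream, which is what the `done` flag
        // ultimately triggers in the real backend.
    });

    // Consumer: stands in for the transport layer reading the schedule() stream.
    let mut stream = UnboundedReceiverStream::new(receiver);
    while let Some(item) = stream.next().await {
        println!("received: {:?}", item);
    }
}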

View File

@ -7,6 +7,7 @@
#include <filesystem>
#include <vector>
#include <spdlog/spdlog.h>
#include "backends/trtllm/include/ffi.h"
@ -21,42 +22,43 @@ bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
}
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
rust::Slice<const uint32_t> tokens,
int32_t maxNewTokens, int32_t topK, float_t topP,
float_t temperature, uint64_t seed) {
rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed) {
// This will copy all the items from the initial slice
std::vector<int32_t> tokens_(tokens.size());
tokens_.assign(tokens.begin(), tokens.end());
return TensorRtLlmBackend::Submit(std::move(tokens_), maxNewTokens, topK, topP, temperature, seed);
return TensorRtLlmBackend::Submit(std::move(tokens_), topK, topP, temperature, seed);
}
uint32_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Stream(
rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
uint64_t requestId,
rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, uint32_t, bool)> handler) {
bool isDone = false;
uint32_t numGeneratedTokens = 0;
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(const uint64_t requestId,
rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, float_t, bool)> callback) {
do {
const auto responses = Poll(requestId);
for (const auto &response: responses) {
if (response.hasError()) {
isDone = true;
// TODO : bubble up the error to rust
} else {
const auto generation = response.getResult();
const auto token = generation.outputTokenIds[0][0];
isDone = generation.isFinal;
SPDLOG_INFO("Entering StreamTokens");
for (const auto &item: Poll(requestId)) {
if (!item.hasError()) {
SPDLOG_INFO("\tStreamTokens -> Decoding token...");
const auto decoded = item.getResult();
SPDLOG_INFO("\tStreamTokens -> Successfully read decoded token ({})", decoded.outputTokenIds[0].size());
// Propagate through the handler
handler(std::move(ctx), token, numGeneratedTokens, isDone);
}
const auto token = decoded.outputTokenIds[0][0];
const auto isFinal = decoded.isFinal;
// const auto logProb = decoded.logProbs.value()[0][0];
const auto logProb = 0.0;
SPDLOG_INFO(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
callback(std::move(ctx), token, logProb, isFinal);
SPDLOG_INFO("\tStreamTokens -> Post callback");
} else {
// TODO : Return rest::Result with error
SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", item.getErrorMsg());
callback(std::move(ctx), 0, 0.0, true);
}
} while (!isDone);
}
return numGeneratedTokens;
SPDLOG_INFO("Exiting StreamTokens");
return 0;
}
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>

View File

@ -17,7 +17,7 @@ mod ffi {
/// Represent an instance of the underlying TensorRT-LLM backend
type TensorRtLlmBackendImpl;
/// Create an instance backed behind an std::unique_ptr to manage the lifespan of the backend
/// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend
///
/// # Arguments
///
@ -37,29 +37,31 @@ mod ffi {
executor_worker: &str,
) -> UniquePtr<TensorRtLlmBackendImpl>;
#[rust_name = "is_ready"]
fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
// #[rust_name = "is_ready"]
// fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
#[rust_name = "num_responses_ready"]
fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
#[rust_name = "submit"]
fn Submit(
self: Pin<&mut TensorRtLlmBackendImpl>,
tokens: &[u32],
max_new_tokens: i32,
top_k: i32,
top_p: f32,
temperature: f32,
seed: u64,
) -> u64;
#[rust_name = "stream"]
fn Stream(
#[rust_name = "stream_tokens"]
fn StreamTokens(
self: Pin<&mut TensorRtLlmBackendImpl>,
ctx: Box<GenerationContext>,
request_id: u64,
callback: fn(Box<GenerationContext>, u32, u32, bool),
) -> u32;
ctx: Box<GenerationContext>,
cb: fn(Box<GenerationContext>, u32, f32, bool),
) -> usize;
#[rust_name = "shutdown"]
fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
// #[rust_name = "shutdown"]
// fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
}
}
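
Editor's illustration of the callback contract declared in the bridge above: stream_tokens hands ownership of the boxed context back to a plain fn pointer along with (token_id, logprob, is_final). The ToyContext type and on_token function below are stand-ins for the opaque GenerationContext and the real callback, not part of the commit.

// Stand-in for the opaque GenerationContext owned by the Rust side.
struct ToyContext {
    collected: Vec<u32>,
}

// Same shape as `cb: fn(Box<GenerationContext>, u32, f32, bool)` in the bridge above.
fn on_token(mut ctx: Box<ToyContext>, token_id: u32, logprob: f32, is_final: bool) {
    ctx.collected.push(token_id);
    println!(
        "token={} logprob={:.2} final={} (seen so far: {})",
        token_id,
        logprob,
        is_final,
        ctx.collected.len()
    );
}

fn main() {
    // In the real backend, the C++ side invokes the callback once per decoded token.
    on_token(Box::new(ToyContext { collected: vec![] }), 20189, 0.0, true);
}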

View File

@ -1,9 +1,11 @@
use std::collections::HashMap;
use std::path::PathBuf;
use clap::Parser;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use text_generation_backends_trtllm::{errors::TensorRtLlmBackendError, TrtLLmBackend};
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
use text_generation_backends_trtllm::TensorRtLlmBackend;
use text_generation_router::server;
/// App Configuration
@ -53,7 +55,13 @@ struct Args {
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(long, env)]
auth_token: Option<String>
auth_token: Option<String>,
#[clap(
long,
env,
help = "Path to the TensorRT-LLM Orchestrator Worker binary"
)]
executor_worker: Option<PathBuf>,
}
#[tokio::main]
@ -83,7 +91,8 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
cors_allow_origin,
messages_api_enabled,
max_client_batch_size,
auth_token
auth_token,
executor_worker,
} = args;
// Launch Tokio runtime
@ -114,6 +123,15 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
}
}
if let Some(ref executor_worker) = executor_worker {
if !executor_worker.exists() {
return Err(TensorRtLlmBackendError::ArgumentValidation(format!(
"`executor_work` specified path doesn't exists: {}",
executor_worker.display()
)));
}
}
// Run server
let tokenizer = Tokenizer::from_pretrained(
tokenizer_name.clone(),
@ -122,9 +140,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
user_agent: HashMap::new(),
auth_token,
}),
).map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
)
.map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
let backend = TrtLLmBackend::new(tokenizer, model_id)?;
let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?;
server::run(
backend,
max_concurrent_requests,