impl RwLock scenario for TensorRtLlmBackend
This commit is contained in: parent 31d9f4d5dc, commit 7784a21d48
@@ -15,6 +15,10 @@ tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot"
tokio-stream = "0.1.14"
clap = { version = "4.5.4", features = ["derive"] }
thiserror = "1.0.61"
tracing = "0.1"
tracing-opentelemetry = "0.24"
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
log = { version = "0.4.21", features = [] }

[build-dependencies]
cmake = "0.1"
@@ -33,6 +33,8 @@ fn main() {
"debug" => format!("{}d", dependency),
_ => String::from(dependency),
};
let dep_path = deps_folder.join(format!("{}-build", dependency));
println!("cargo:rustc-link-search={}", dep_path.display());
println!("cargo:rustc-link-lib=static={}", dep_name);
}
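For readers who want the full picture of the linker-directive pattern visible in the hunk above, here is a minimal, self-contained Cargo build-script sketch. The dependency names, the PROFILE handling, and the build/_deps layout are illustrative assumptions, not the crate's actual build.rs:

// Hypothetical build.rs sketch: emits the same cargo:rustc-link-* directives as the
// hunk above. Dependency names and the deps folder layout are assumed for illustration.
use std::path::PathBuf;

fn main() {
    // Cargo exposes the active profile ("debug" or "release") to build scripts.
    let profile = std::env::var("PROFILE").unwrap_or_else(|_| "release".to_string());
    // Assumed location of the CMake-built static dependencies.
    let deps_folder = PathBuf::from("build").join("_deps");

    for dependency in ["fmt", "spdlog"] {
        // Debug artifacts of some libraries are conventionally suffixed with `d`.
        let dep_name = match profile.as_str() {
            "debug" => format!("{}d", dependency),
            _ => String::from(dependency),
        };
        let dep_path = deps_folder.join(format!("{}-build", dependency));
        println!("cargo:rustc-link-search={}", dep_path.display());
        println!("cargo:rustc-link-lib=static={}", dep_name);
    }
}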
@@ -17,14 +17,11 @@ else ()
set(FAST_BUILD OFF)
endif ()

# This line turns off DEBUG in the TRTLLM logger, which is quite spammy
add_compile_definitions(NDEBUG OFF)

fetchcontent_declare(
trtllm
GIT_REPOSITORY https://github.com/nvidia/tensorrt-llm.git
GIT_TAG a96cccafcf6365c128f004f779160951f8c0801c
GIT_SHALLOW TRUE
GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
GIT_TAG 9691e12bce7ae1c126c435a049eb516eb119486c
GIT_SHALLOW FALSE
)
fetchcontent_makeavailable(trtllm)
message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")
@@ -5,7 +5,7 @@
#ifndef TGI_TRTLLM_BACKEND_FFI_H
#define TGI_TRTLLM_BACKEND_FFI_H

//#include "rust/cxx.h"
#include <cstddef>
#include "backend.h"

namespace huggingface::tgi::backends {

@@ -17,9 +17,9 @@ namespace huggingface::tgi::backends {

namespace huggingface::tgi::backends {

struct GenerationContext;
// struct GenerationContext;

class TensorRtLlmBackendImpl : TensorRtLlmBackend {
class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
public:
/***
*

@@ -37,7 +37,6 @@ namespace huggingface::tgi::backends {
/***
*
* @param tokens
* @param maxNewTokens
* @param topK
* @param topP
* @param temperature

@@ -45,17 +44,20 @@ namespace huggingface::tgi::backends {
* @return
*/
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
uint64_t Submit(rust::Slice<const uint32_t> tokens, int32_t maxNewTokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed);
uint64_t
Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed);

/***
*
* @param requestId
* @param handler
* @param ctx
* @param callback
* @return
*/
uint32_t Stream(rust::Box <GenerationContext> ctx,
uint64_t requestId,
rust::Fn<void(rust::Box<GenerationContext>, uint32_t, uint32_t, bool)> handler);
size_t StreamTokens(
const RequestId requestId,
rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, float_t, bool)> callback);
};

/***
@@ -1,160 +1,311 @@
use std::cell::RefCell;
use std::future::Future;
use std::path::Path;
use std::pin::{pin, Pin};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;

use async_trait::async_trait;
use cxx::UniquePtr;
use log::{info, warn};
use tokenizers::Tokenizer;
use tokio::sync::mpsc;
use tokio::time::Instant;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::sync::RwLock;
use tokio::time::{Instant, sleep};
use tokio_stream::{Stream, StreamExt};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{instrument, Level, span};

use text_generation_router::{FinishReason, Token};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidParameters};
use text_generation_router::validation::ValidGenerateRequest;

use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};

// macro_rules! propagate {
// ($ctx: expr, $res: expr) => {
// $ctx.sender
// .send($res)
// .expect("Failed to propagate error back to the transport layer")
// };
// }

type InferResult<T> = Result<T, InferError>;

pub struct GenerationContext(mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>);
/// Holds the user provided input to be executed along with a channel allowing
/// to bubble up all the generated tokens for that request back to the end stream.
// pub struct InferenceContext {
// /// User provided request
// request: ValidGenerateRequest,
//
// /// Inter-process communication handler moving token from the executor thread to the HTTP server
// sender: UnboundedSender<InferResult<InferStreamResponse>>,
//
// /// Pin the instant this inference context was submitted
// when: Instant,
//
// /// Span that will live as long as entry
// span: Span,
// }

pub struct TrtLLmBackend {
tokenizer: Tokenizer,
inner: RefCell<UniquePtr<TensorRtLlmBackendImpl>>,
pub(crate) struct Generation {
executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
done: Arc<AtomicBool>,
}

unsafe impl Sync for TrtLLmBackend {}
unsafe impl Send for TrtLLmBackend {}
pub struct GenerationContext(
UnboundedSender<InferResult<InferStreamResponse>>,
Arc<AtomicBool>,
);

impl TrtLLmBackend {
pub fn new<P: AsRef<Path>>(
tokenizer: Tokenizer,
engine_folder: P,
) -> Result<Self, TensorRtLlmBackendError> {
let engine_folder = engine_folder.as_ref();
let inner = create_tensorrt_llm_backend(engine_folder.to_str().unwrap(), "");
impl Stream for Generation {
type Item = usize;

Ok(Self {
tokenizer,
inner: RefCell::new(inner),
})
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.done.load(Ordering::Relaxed) {
Poll::Ready(None)
} else {
let pinned = pin!(self.executor.read());
match pinned.poll(ctx) {
Poll::Ready(executor_r) => {
let ready = executor_r.num_responses_ready();
if ready == 0 {
let waker = ctx.waker().clone();
tokio::spawn(async {
sleep(Duration::from_millis(10)).await;
waker.wake();
});
Poll::Pending
} else {
info!("Ready: {}", ready);
let waker = ctx.waker().clone();
tokio::spawn(async {
sleep(Duration::from_millis(100)).await;
waker.wake();
});
Poll::Ready(Some(ready))
}
}
Poll::Pending => {
let waker = ctx.waker().clone();
tokio::spawn(async {
sleep(Duration::from_millis(100)).await;
waker.wake();
});
Poll::Pending
}
}
}
}

fn infer_text(
&self,
ctx: GenerationContext,
text: &str,
params: ValidParameters,
) -> InferResult<()> {
// Keep track of processing time
let start = Instant::now();
fn size_hint(&self) -> (usize, Option<usize>) {
(1, None)
}
}

// Encode the input
let ctx = Box::new(ctx);
let encoding = self
.tokenizer
.encode(text, true)
.map_err(|e| InferError::ToolError(e.to_string()))?;
unsafe impl Send for TensorRtLlmBackendImpl {}
unsafe impl Sync for TensorRtLlmBackendImpl {}

// Submit the request to the backend and retrieve the handle to query its status
let request_id = self
.inner
.borrow_mut()
.as_mut()
.expect("Failed to retrieve pointer to TRTLLM backend")
.submit(
encoding.get_ids(),
128,
params.top_k as i32,
params.top_p,
params.temperature,
params.seed,
);
/// Implements the logic to execute generation with TensorRT-LLM executor API in background
pub struct TensorRtLlmBackend {
// Allowing sending user requests to the TensorRT-LLM executor thread
// batcher: UnboundedSender<InferenceContext>,
backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
}

// Stream generated tokens
// spawn_blocking(move || {
let num_generated_tokens = self
.inner
.borrow_mut()
.as_mut()
.expect("Failed to retrieve pointer to TRTLLM backend")
.stream(ctx, request_id, |ctx, token, step, is_final| {
// self.tokenizer.decode(&*[token], true).unwrap();
let sender = ctx.0;
let token = Token {
id: token,
text: String::from(""),
logprob: 1.0f32,
special: false,
};

sender
.send(Ok(InferStreamResponse::Intermediate {
token,
top_tokens: vec![],
}))
.unwrap()
});

// Notify the end
let _ = ctx.0.send(Ok(InferStreamResponse::End {
token: Token {
id: 0,
text: String::from(""),
logprob: 1.0f32,
special: false,
},
top_tokens: vec![],
generated_text: GeneratedText {
text: String::from(""),
generated_tokens: num_generated_tokens,
finish_reason: FinishReason::EndOfSequenceToken,
seed: Some(params.seed),
},
start,
queued: Instant::now(),
}));
// });

Ok(())
impl TensorRtLlmBackend {
pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
_tokenizer: Tokenizer,
engine_folder: P,
_executor_worker_path: Option<PP>,
) -> Result<Self, TensorRtLlmBackendError> {
Ok(TensorRtLlmBackend {
backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
engine_folder.as_ref().to_str().unwrap(),
"",
))),
})
}
}

#[async_trait]
impl Backend for TrtLLmBackend {
impl Backend for TensorRtLlmBackend {
#[instrument(skip_all)]
fn schedule(
&self,
request: ValidGenerateRequest,
_request: ValidGenerateRequest,
) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
let (sender, receiver) = mpsc::unbounded_channel();
let ctx = GenerationContext(sender);
// Channel to stream the generated token as they come from the worker thread back to the transport layer
let (sender, receiver) = unbounded_channel();

// Unpack parameters
let params = request.parameters;
let executor = self.backend.clone();
tokio::spawn(async move {
// Submit the request to the batcher
let request_id = span!(Level::DEBUG, "[EXECUTOR][SUBMIT]")
.in_scope(|| async {
info!("Acquiring lock for submit");
let mut handle = executor.write().await;
let request_id = handle.pin_mut().submit(
&vec![2, 2926, 1503, 603, 20189],
50,
1.0,
1.0,
2014,
);

// Ensure we are running in the right conditions for the input (i.e. single textual chunk)
let input = match request.inputs.len() {
0 => Err(InferError::GenerationError("No input provided".into())),
1 => Ok(request.inputs.first().unwrap()),
_ => Err(InferError::GenerationError(format!(
"Unsupported multi-chunks ({}) inference.",
request.inputs.len()
))),
}?;
info!("Releasing lock for submit");
return request_id;
})
.await;

// Currently we handle single chunk of text
match input {
Chunk::Text(text) => {
self.infer_text(ctx, &**text, params)?;
let mut generation = Generation {
executor: executor.clone(),
done: Arc::new(AtomicBool::new(false)),
};

while let Some(num_tokens_ready) = generation.next().await {
span!(
Level::DEBUG,
"[EXECUTOR][GENERATE]",
request_id = request_id,
num_tokens_ready = num_tokens_ready
)
.in_scope(|| async {
let ctx = Box::new(GenerationContext(
sender.clone(),
Arc::clone(&generation.done),
));
let mut executor_w = executor.write().await;

info!("Acquired write lock stream");
executor_w.pin_mut().stream_tokens(
request_id,
ctx,
|ctx: Box<GenerationContext>, token: u32, logprob: f32, is_final: bool| {
info!("Sending token: {} (final: {})", token, is_final);
let out = if is_final {
ctx.1.store(true, Ordering::Relaxed);
InferStreamResponse::End {
token: Token {
id: token,
text: "".into(),
logprob,
special: false,
},
top_tokens: vec![],
generated_text: GeneratedText {
text: "".into(),
generated_tokens: u32::MAX,
finish_reason: FinishReason::EndOfSequenceToken,
seed: None,
},
start: Instant::now(),
queued: Instant::now(),
}
} else {
InferStreamResponse::Intermediate {
token: Token {
id: token,
text: "".into(),
logprob,
special: false,
},
top_tokens: vec![],
}
};
ctx.0
.send(Ok(out))
.expect("Failed to send back generated token");
},
);
info!("Releasing write lock stream")
})
.await;
}
Chunk::Image(_) => panic!("Unsupported"),
};
});

Ok(UnboundedReceiverStream::new(receiver))
}

async fn health(&self, _current_health: bool) -> bool {
self.inner.borrow_mut().is_ready()
true
}
}

// async fn background_looper<P: AsRef<Path>, PP: AsRef<Path>>(
// engine_folder: P,
// _executor_worker: Option<PP>,
// tokenizer: Tokenizer,
// mut receiver: UnboundedReceiver<InferenceContext>,
// ) {
// let mut backend = create_tensorrt_llm_backend(engine_folder.as_ref().to_str().unwrap(), "");
//
// while !(receiver.is_closed()) {
// // Receive the incoming request
// if let Some(ctx) = receiver.recv().await {
// debug!("Processing new incoming request");
//
// // We only support single, textual chunk
// if ctx.request.inputs.len() != 1 {
// propagate!(
// ctx,
// Err(InferError::GenerationError(format!(
// "Unsupported multi-chunk ({}) input",
// ctx.request.inputs.len()
// )))
// );
// }
//
// let input = ctx
// .request
// .inputs
// .first()
// .expect("Single chunk checked above");
// let params = ctx.request.parameters;
// }
// }

// Receive the incoming request
// if let Some(ctx) = receiver.recv().await {
// debug!("Processing new incoming request");

// // We only support single, textual chunk
// if ctx.request.inputs.len() != 1 {
// propagate!(
// ctx,
// Err(InferError::GenerationError(format!(
// "Unsupported multi-chunk ({}) input",
// ctx.request.inputs.len()
// )))
// );
// }
//
// // Unpack parameters
// let inputs = ctx.request.inputs;
// let params = ctx.request.parameters;
//
// match inputs.first().unwrap() {
// Chunk::Text(text) => match tokenizer.encode(text.as_str(), true) {
// Err(err) => {
// propagate!(ctx, Err(InferError::GenerationError(err.to_string())))
// }
// Ok(encoding) => {
// // spawn_blocking(|| {
// // info!("Submitting request to TensorRT-LLM executor");
// // let mut executor = backend.blocking_write();
// // })
// // .await
// // .expect("");
// }
// },
// Chunk::Image(_) => propagate!(
// ctx,
// Err(InferError::GenerationError(
// "Image input is not supported yet.".into(),
// ))
// ),
// }
// };
// }
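The heart of this change is the Generation stream above: poll_next first checks the shared done flag, then polls the executor's RwLock read future by hand, and whenever nothing is ready it re-arms itself by spawning a short sleep that wakes the task later. Below is a self-contained sketch of that wake-up pattern with a plain atomic counter standing in for the FFI executor; the names, timings, and the "drain" step are illustrative assumptions, not the backend's real types (assumes tokio with the time/macros features plus tokio-stream):

use std::future::Future;
use std::pin::{pin, Pin};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;

use tokio::sync::RwLock;
use tokio::time::sleep;
use tokio_stream::{Stream, StreamExt};

// Stand-in for the C++ executor: the counter plays the role of num_responses_ready().
struct Generation {
    executor: Arc<RwLock<AtomicUsize>>,
    done: Arc<AtomicBool>,
}

impl Stream for Generation {
    type Item = usize;

    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if self.done.load(Ordering::Relaxed) {
            return Poll::Ready(None);
        }
        // Poll the async read-lock acquisition without awaiting it.
        let pinned = pin!(self.executor.read());
        match pinned.poll(ctx) {
            Poll::Ready(executor) => {
                // "Drain" the ready count so the demo yields each batch once.
                let ready = executor.swap(0, Ordering::Relaxed);
                if ready == 0 {
                    // Nothing ready: schedule a wake-up instead of busy-looping.
                    let waker = ctx.waker().clone();
                    tokio::spawn(async move {
                        sleep(Duration::from_millis(10)).await;
                        waker.wake();
                    });
                    Poll::Pending
                } else {
                    Poll::Ready(Some(ready))
                }
            }
            Poll::Pending => {
                // The lock future is dropped when this call returns, taking its waker
                // registration with it, so re-arm with a timed wake-up here as well.
                let waker = ctx.waker().clone();
                tokio::spawn(async move {
                    sleep(Duration::from_millis(10)).await;
                    waker.wake();
                });
                Poll::Pending
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let executor = Arc::new(RwLock::new(AtomicUsize::new(0)));
    let done = Arc::new(AtomicBool::new(false));
    let mut generation = Generation {
        executor: Arc::clone(&executor),
        done: Arc::clone(&done),
    };

    // Simulate the executor producing a batch of responses, then finishing.
    tokio::spawn(async move {
        sleep(Duration::from_millis(50)).await;
        executor.read().await.store(3, Ordering::Relaxed);
        sleep(Duration::from_millis(50)).await;
        done.store(true, Ordering::Relaxed);
    });

    while let Some(ready) = generation.next().await {
        println!("{} responses ready", ready);
    }
}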
@@ -7,6 +7,7 @@
#include <filesystem>
#include <vector>

#include <spdlog/spdlog.h>
#include "backends/trtllm/include/ffi.h"

@@ -21,42 +22,43 @@ bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
}

uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
rust::Slice<const uint32_t> tokens,
int32_t maxNewTokens, int32_t topK, float_t topP,
float_t temperature, uint64_t seed) {
rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed) {

// This will copy all the items from the initial slice
std::vector<int32_t> tokens_(tokens.size());
tokens_.assign(tokens.begin(), tokens.end());

return TensorRtLlmBackend::Submit(std::move(tokens_), maxNewTokens, topK, topP, temperature, seed);
return TensorRtLlmBackend::Submit(std::move(tokens_), topK, topP, temperature, seed);
}

uint32_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Stream(
rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
uint64_t requestId,
rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, uint32_t, bool)> handler) {
bool isDone = false;
uint32_t numGeneratedTokens = 0;
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(const uint64_t requestId,
rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, float_t, bool)> callback) {

do {
const auto responses = Poll(requestId);
for (const auto &response: responses) {
if (response.hasError()) {
isDone = true;
// TODO : bubble up the error to rust
} else {
const auto generation = response.getResult();
const auto token = generation.outputTokenIds[0][0];
isDone = generation.isFinal;
SPDLOG_INFO("Entering StreamTokens");
for (const auto &item: Poll(requestId)) {
if (!item.hasError()) {
SPDLOG_INFO("\tStreamTokens -> Decoding token...");
const auto decoded = item.getResult();
SPDLOG_INFO("\tStreamTokens -> Successfully read decoded token ({})", decoded.outputTokenIds[0].size());

// Propagate through the handler
handler(std::move(ctx), token, numGeneratedTokens, isDone);
}
const auto token = decoded.outputTokenIds[0][0];
const auto isFinal = decoded.isFinal;
// const auto logProb = decoded.logProbs.value()[0][0];
const auto logProb = 0.0;

SPDLOG_INFO(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
callback(std::move(ctx), token, logProb, isFinal);
SPDLOG_INFO("\tStreamTokens -> Post callback");
} else {
// TODO : Return rest::Result with error
SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", item.getErrorMsg());
callback(std::move(ctx), 0, 0.0, true);
}
} while (!isDone);
}

return numGeneratedTokens;
SPDLOG_INFO("Exiting StreamTokens");
return 0;
}

std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
@@ -17,7 +17,7 @@ mod ffi {
/// Represent an instance of the underlying TensorRT-LLM backend
type TensorRtLlmBackendImpl;

/// Create an instance backed behind an std::unique_ptr to manage the lifespan of the backend
/// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend
///
/// # Arguments
///

@@ -37,29 +37,31 @@ mod ffi {
executor_worker: &str,
) -> UniquePtr<TensorRtLlmBackendImpl>;

#[rust_name = "is_ready"]
fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
// #[rust_name = "is_ready"]
// fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;

#[rust_name = "num_responses_ready"]
fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;

#[rust_name = "submit"]
fn Submit(
self: Pin<&mut TensorRtLlmBackendImpl>,
tokens: &[u32],
max_new_tokens: i32,
top_k: i32,
top_p: f32,
temperature: f32,
seed: u64,
) -> u64;

#[rust_name = "stream"]
fn Stream(
#[rust_name = "stream_tokens"]
fn StreamTokens(
self: Pin<&mut TensorRtLlmBackendImpl>,
ctx: Box<GenerationContext>,
request_id: u64,
callback: fn(Box<GenerationContext>, u32, u32, bool),
) -> u32;
ctx: Box<GenerationContext>,
cb: fn(Box<GenerationContext>, u32, f32, bool),
) -> usize;

#[rust_name = "shutdown"]
fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
// #[rust_name = "shutdown"]
// fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
}
}
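One detail worth calling out in the bridge above: stream_tokens takes a plain fn(Box<GenerationContext>, u32, f32, bool) rather than a closure type, so the callback cannot capture any state; everything it needs (the response channel and the done flag) has to travel through the boxed GenerationContext that is handed back on each invocation. A stand-alone sketch of that pattern follows; it involves no FFI, and the stream_tokens stand-in here is hypothetical:

use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};

// Mirrors the opaque context shared across the bridge: a channel plus a completion flag.
pub struct GenerationContext(UnboundedSender<u32>, Arc<AtomicBool>);

// Stand-in for the bridged `stream_tokens`: it receives the boxed context and a plain
// function pointer, matching the shape of `fn(Box<GenerationContext>, u32, f32, bool)`.
fn stream_tokens(ctx: Box<GenerationContext>, cb: fn(Box<GenerationContext>, u32, f32, bool)) -> usize {
    // Pretend the executor decoded exactly one, final token.
    cb(ctx, 42, 0.0, true);
    1
}

fn main() {
    let (sender, mut receiver) = unbounded_channel();
    let done = Arc::new(AtomicBool::new(false));
    let ctx = Box::new(GenerationContext(sender, Arc::clone(&done)));

    // A non-capturing closure coerces to `fn(..)`; all state flows in through `ctx`.
    stream_tokens(ctx, |ctx, token, _logprob, is_final| {
        ctx.0.send(token).expect("receiver dropped");
        if is_final {
            ctx.1.store(true, Ordering::Relaxed);
        }
    });

    assert!(done.load(Ordering::Relaxed));
    assert_eq!(receiver.try_recv().ok(), Some(42));
}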
@@ -1,9 +1,11 @@
use std::collections::HashMap;
use std::path::PathBuf;

use clap::Parser;
use tokenizers::{FromPretrainedParameters, Tokenizer};

use text_generation_backends_trtllm::{errors::TensorRtLlmBackendError, TrtLLmBackend};
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
use text_generation_backends_trtllm::TensorRtLlmBackend;
use text_generation_router::server;

/// App Configuration

@@ -53,7 +55,13 @@ struct Args {
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(long, env)]
auth_token: Option<String>
auth_token: Option<String>,
#[clap(
long,
env,
help = "Path to the TensorRT-LLM Orchestrator Worker binary"
)]
executor_worker: Option<PathBuf>,
}

#[tokio::main]

@@ -83,7 +91,8 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
cors_allow_origin,
messages_api_enabled,
max_client_batch_size,
auth_token
auth_token,
executor_worker,
} = args;

// Launch Tokio runtime

@@ -114,6 +123,15 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
}
}

if let Some(ref executor_worker) = executor_worker {
if !executor_worker.exists() {
return Err(TensorRtLlmBackendError::ArgumentValidation(format!(
"`executor_work` specified path doesn't exists: {}",
executor_worker.display()
)));
}
}

// Run server
let tokenizer = Tokenizer::from_pretrained(
tokenizer_name.clone(),

@@ -122,9 +140,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
user_agent: HashMap::new(),
auth_token,
}),
).map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
)
.map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;

let backend = TrtLLmBackend::new(tokenizer, model_id)?;
let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?;
server::run(
backend,
max_concurrent_requests,