end to end ffi flow working

This commit is contained in:
Morgan Funtowicz 2024-07-12 19:25:40 +00:00
parent b846ae2d9e
commit 344f33f398
3 changed files with 139 additions and 90 deletions

View File

@ -0,0 +1,69 @@
//
// Created by mfuntowicz on 7/11/24.
//
#ifndef TGI_TRTLLM_BACKEND_FFI_H
#define TGI_TRTLLM_BACKEND_FFI_H
//#include "rust/cxx.h"
#include "backend.h"
namespace huggingface::tgi::backends {
class TensorRtLlmBackendImpl;
}
#include "backends/trtllm/src/lib.rs.h"
namespace huggingface::tgi::backends {
struct GenerationContext;
class TensorRtLlmBackendImpl : TensorRtLlmBackend {
public:
/***
*
* @param engineFolder
* @param executorWorker
*/
TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
/***
*
* @return
*/
bool IsReady() const;
/***
*
* @param tokens
* @param maxNewTokens
* @param topK
* @param topP
* @param temperature
* @param seed
* @return
*/
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
uint64_t Submit(rust::Slice<const uint32_t> tokens, int32_t maxNewTokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed);
/***
*
* @param requestId
* @param handler
* @return
*/
uint32_t Stream(rust::Box <GenerationContext> ctx,
uint64_t requestId,
rust::Fn<void(rust::Box<GenerationContext>, uint32_t, uint32_t, bool)> handler);
};
/***
*
* @param engineFolder
* @return
*/
std::unique_ptr<TensorRtLlmBackendImpl> CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker);
}
#endif //TGI_TRTLLM_BACKEND_FFI_H

View File

@ -7,93 +7,66 @@
#include <filesystem> #include <filesystem>
#include <vector> #include <vector>
#include "rust/cxx.h" //#include "rust/cxx.h"
#include "backends/trtllm/include/backend.h" //#include "../include/ffi.h"
#include "backends/trtllm/include/ffi.h"
namespace huggingface::tgi::backends {
class TensorRtLlmBackendImpl : TensorRtLlmBackend {
public:
/***
*
* @param engineFolder
* @param executorWorker
*/
TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker) :
TensorRtLlmBackend(std::move(engineFolder), std::move(executorWorker)) {}
/*** huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
* const std::string_view &engineFolder,
* @return const std::string_view &executorWorker
*/ ) : TensorRtLlmBackend(engineFolder, executorWorker) {}
bool IsReady() const { return TensorRtLlmBackend::IsReady(); }
bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
return TensorRtLlmBackend::IsReady();
}
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
rust::Slice<const uint32_t> tokens,
int32_t maxNewTokens, int32_t topK, float_t topP,
float_t temperature, uint64_t seed) {
/***
*
* @param tokens
* @param maxNewTokens
* @param topK
* @param topP
* @param temperature
* @param seed
* @return
*/
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
RequestId Submit(rust::Slice<const uint32_t> tokens,
int32_t maxNewTokens,
int32_t topK,
float_t topP,
float_t temperature,
uint64_t seed) {
// This will copy all the items from the initial slice // This will copy all the items from the initial slice
std::vector<int32_t> tokens_(tokens.size()); std::vector<int32_t> tokens_(tokens.size());
tokens_.assign(tokens.begin(), tokens.end()); tokens_.assign(tokens.begin(), tokens.end());
return TensorRtLlmBackend::Submit(std::move(tokens_), maxNewTokens, topK, topP, temperature, seed); return TensorRtLlmBackend::Submit(std::move(tokens_), maxNewTokens, topK, topP, temperature, seed);
}
uint32_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Stream(
rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
uint64_t requestId,
rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, uint32_t, bool)> handler) {
bool isDone = false;
uint32_t numGeneratedTokens = 0;
do {
const auto responses = Poll(requestId);
for (const auto &response: responses) {
if (response.hasError()) {
isDone = true;
// TODO : bubble up the error to rust
} else {
const auto generation = response.getResult();
const auto token = generation.outputTokenIds[0][0];
isDone = generation.isFinal;
// Propagate through the handler
handler(std::move(ctx), token, numGeneratedTokens, isDone);
} }
}
} while (!isDone);
/*** return numGeneratedTokens;
* }
* @param requestId
* @param handler
* @return
*/
// uint32_t
// Stream(RequestId requestId, rust::Box <GenerationContext>, rust::Fn<void(uint32_t, uint32_t, bool)> handler) {
// bool isDone = false;
// uint32_t numGeneratedTokens = 0;
//
// do {
// const auto responses = Poll(requestId);
// for (const auto &response: responses) {
// if (response.hasError()) {
// isDone = true;
// // TODO : bubble up the error to rust
// } else {
// const auto generation = response.getResult();
// const auto token = generation.outputTokenIds[0][0];
// isDone = generation.isFinal;
//
// // Propagate through the handler
// handler(token, numGeneratedTokens, isDone);
// }
// }
// } while (!isDone);
//
// return numGeneratedTokens;
// }
};
/*** std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
* huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
* @param engineFolder
* @return
*/
std::unique_ptr<TensorRtLlmBackendImpl> create_trtllm_backend(rust::Str engineFolder, rust::Str executorWorker) {
// Unconditionally call this to initialize and discover TRTLLM plugins // Unconditionally call this to initialize and discover TRTLLM plugins
InitializeBackend(); InitializeBackend();
const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end()); const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end()); const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end());
return std::make_unique<TensorRtLlmBackendImpl>(std::move(enginePath), std::move(executorPath)); return std::make_unique<TensorRtLlmBackendImpl>(std::move(enginePath), std::move(executorPath));
}
} }

View File

@ -1,10 +1,16 @@
pub use backend::TrtLLmBackend; pub use backend::TrtLLmBackend;
use crate::backend::GenerationContext;
mod backend; mod backend;
pub mod errors; pub mod errors;
#[cxx::bridge(namespace = "huggingface::tgi::backends")] #[cxx::bridge(namespace = "huggingface::tgi::backends")]
mod ffi { mod ffi {
extern "Rust" {
type GenerationContext;
}
unsafe extern "C++" { unsafe extern "C++" {
include!("backends/trtllm/src/ffi.cpp"); include!("backends/trtllm/src/ffi.cpp");
@ -25,7 +31,8 @@ mod ffi {
/// ``` /// ```
/// ///
/// ``` /// ```
fn create_trtllm_backend( #[rust_name = "create_tensorrt_llm_backend"]
fn CreateTensorRtLlmBackend(
engine_folder: &str, engine_folder: &str,
executor_worker: &str, executor_worker: &str,
) -> UniquePtr<TensorRtLlmBackendImpl>; ) -> UniquePtr<TensorRtLlmBackendImpl>;
@ -44,12 +51,12 @@ mod ffi {
seed: u64, seed: u64,
) -> u64; ) -> u64;
// #[rust_name = "stream"] #[rust_name = "stream"]
// fn Stream( fn Stream(
// self: Pin<&mut TensorRtLlmBackendImpl>, self: Pin<&mut TensorRtLlmBackendImpl>,
// request_id: u64, ctx: Box<GenerationContext>,
// ctx: Box<GenerationContext>, request_id: u64,
// callback: fn(u32, u32, bool), callback: fn(Box<GenerationContext>, u32, u32, bool),
// ) -> u32; ) -> u32;
} }
} }