End-to-end FFI flow working
Parent: b846ae2d9e
Commit: 344f33f398
--- /dev/null
+++ b/backends/trtllm/include/ffi.h
@@ -0,0 +1,69 @@
+//
+// Created by mfuntowicz on 7/11/24.
+//
+
+#ifndef TGI_TRTLLM_BACKEND_FFI_H
+#define TGI_TRTLLM_BACKEND_FFI_H
+
+//#include "rust/cxx.h"
+#include "backend.h"
+
+namespace huggingface::tgi::backends {
+    class TensorRtLlmBackendImpl;
+}
+
+#include "backends/trtllm/src/lib.rs.h"
+
+namespace huggingface::tgi::backends {
+
+    struct GenerationContext;
+
+    class TensorRtLlmBackendImpl : TensorRtLlmBackend {
+    public:
+        /***
+         *
+         * @param engineFolder
+         * @param executorWorker
+         */
+        TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
+
+        /***
+         *
+         * @return
+         */
+        bool IsReady() const;
+
+        /***
+         *
+         * @param tokens
+         * @param maxNewTokens
+         * @param topK
+         * @param topP
+         * @param temperature
+         * @param seed
+         * @return
+         */
+        [[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
+        uint64_t Submit(rust::Slice<const uint32_t> tokens, int32_t maxNewTokens, int32_t topK, float_t topP, float_t temperature, uint64_t seed);
+
+        /***
+         *
+         * @param requestId
+         * @param handler
+         * @return
+         */
+        uint32_t Stream(rust::Box<GenerationContext> ctx,
+                        uint64_t requestId,
+                        rust::Fn<void(rust::Box<GenerationContext>, uint32_t, uint32_t, bool)> handler);
+    };
+
+    /***
+     *
+     * @param engineFolder
+     * @return
+     */
+    std::unique_ptr<TensorRtLlmBackendImpl> CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker);
+}
+
+#endif //TGI_TRTLLM_BACKEND_FFI_H
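For orientation, the cxx-facing types in this header map directly onto Rust types in the bridge declared in lib.rs further down: rust::Str corresponds to &str, rust::Slice<const uint32_t> to &[u32], rust::Box<GenerationContext> to Box<GenerationContext>, and rust::Fn<void(...)> to a plain Rust fn pointer. The sketch below shows the Rust-side shape this implies; the real GenerationContext lives in backend.rs, which is not part of this diff, so the field shown is purely illustrative.

// Hypothetical sketch: the actual GenerationContext is defined in crate::backend
// (see `use crate::backend::GenerationContext;` in lib.rs below); its contents are
// not shown in this commit, so a channel sender is only one plausible shape.
use std::sync::mpsc::Sender;

pub struct GenerationContext {
    pub sender: Sender<(u32, bool)>, // assumed field: (token_id, is_final)
}

// Matches rust::Fn<void(rust::Box<GenerationContext>, uint32_t, uint32_t, bool)>:
// a plain fn pointer that receives ownership of the context on each call.
fn on_token(ctx: Box<GenerationContext>, token_id: u32, n_generated: u32, is_final: bool) {
    let _ = n_generated; // running count reported by the C++ side
    let _ = ctx.sender.send((token_id, is_final));
}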
--- a/backends/trtllm/src/ffi.cpp
+++ b/backends/trtllm/src/ffi.cpp
@@ -7,93 +7,66 @@
 #include <filesystem>
 #include <vector>
 
-#include "rust/cxx.h"
-#include "backends/trtllm/include/backend.h"
+//#include "rust/cxx.h"
+//#include "../include/ffi.h"
+#include "backends/trtllm/include/ffi.h"
 
-namespace huggingface::tgi::backends {
-    class TensorRtLlmBackendImpl : TensorRtLlmBackend {
-    public:
-        /***
-         *
-         * @param engineFolder
-         * @param executorWorker
-         */
-        TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker) :
-                TensorRtLlmBackend(std::move(engineFolder), std::move(executorWorker)) {}
+huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
+        const std::string_view &engineFolder,
+        const std::string_view &executorWorker
+) : TensorRtLlmBackend(engineFolder, executorWorker) {}
 
-        /***
-         *
-         * @return
-         */
-        bool IsReady() const { return TensorRtLlmBackend::IsReady(); }
+bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
+    return TensorRtLlmBackend::IsReady();
+}
 
-        /***
-         *
-         * @param tokens
-         * @param maxNewTokens
-         * @param topK
-         * @param topP
-         * @param temperature
-         * @param seed
-         * @return
-         */
-        [[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
-        RequestId Submit(rust::Slice<const uint32_t> tokens,
-                         int32_t maxNewTokens,
-                         int32_t topK,
-                         float_t topP,
-                         float_t temperature,
-                         uint64_t seed) {
-            // This will copy all the items from the initial slice
-            std::vector<int32_t> tokens_(tokens.size());
-            tokens_.assign(tokens.begin(), tokens.end());
+uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
+        rust::Slice<const uint32_t> tokens,
+        int32_t maxNewTokens, int32_t topK, float_t topP,
+        float_t temperature, uint64_t seed) {
+    // This will copy all the items from the initial slice
+    std::vector<int32_t> tokens_(tokens.size());
+    tokens_.assign(tokens.begin(), tokens.end());
 
-            return TensorRtLlmBackend::Submit(std::move(tokens_), maxNewTokens, topK, topP, temperature, seed);
-        }
+    return TensorRtLlmBackend::Submit(std::move(tokens_), maxNewTokens, topK, topP, temperature, seed);
+}
 
-        /***
-         *
-         * @param requestId
-         * @param handler
-         * @return
-         */
-//        uint32_t
-//        Stream(RequestId requestId, rust::Box <GenerationContext>, rust::Fn<void(uint32_t, uint32_t, bool)> handler) {
-//            bool isDone = false;
-//            uint32_t numGeneratedTokens = 0;
-//
-//            do {
-//                const auto responses = Poll(requestId);
-//                for (const auto &response: responses) {
-//                    if (response.hasError()) {
-//                        isDone = true;
-//                        // TODO : bubble up the error to rust
-//                    } else {
-//                        const auto generation = response.getResult();
-//                        const auto token = generation.outputTokenIds[0][0];
-//                        isDone = generation.isFinal;
-//
-//                        // Propagate through the handler
-//                        handler(token, numGeneratedTokens, isDone);
-//                    }
-//                }
-//            } while (!isDone);
-//
-//            return numGeneratedTokens;
-//        }
-    };
+uint32_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Stream(
+        rust::Box<huggingface::tgi::backends::GenerationContext> ctx,
+        uint64_t requestId,
+        rust::Fn<void(rust::Box<huggingface::tgi::backends::GenerationContext>, uint32_t, uint32_t, bool)> handler) {
+    bool isDone = false;
+    uint32_t numGeneratedTokens = 0;
 
-    /***
-     *
-     * @param engineFolder
-     * @return
-     */
-    std::unique_ptr<TensorRtLlmBackendImpl> create_trtllm_backend(rust::Str engineFolder, rust::Str executorWorker) {
-        // Unconditionally call this to initialize and discover TRTLLM plugins
-        InitializeBackend();
+    do {
+        const auto responses = Poll(requestId);
+        for (const auto &response: responses) {
+            if (response.hasError()) {
+                isDone = true;
+                // TODO : bubble up the error to rust
+            } else {
+                const auto generation = response.getResult();
+                const auto token = generation.outputTokenIds[0][0];
+                isDone = generation.isFinal;
 
-        const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
-        const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end());
-        return std::make_unique<TensorRtLlmBackendImpl>(std::move(enginePath), std::move(executorPath));
-    }
-}
+                // Propagate through the handler
+                handler(std::move(ctx), token, numGeneratedTokens, isDone);
+            }
+        }
+    } while (!isDone);
+
+    return numGeneratedTokens;
+}
+
+std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
+huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
+    // Unconditionally call this to initialize and discover TRTLLM plugins
+    InitializeBackend();
+
+    const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
+    const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end());
+    return std::make_unique<TensorRtLlmBackendImpl>(std::move(enginePath), std::move(executorPath));
+}
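The Stream implementation above blocks in a poll loop: it calls Poll(requestId) until a response is marked final and invokes the Rust handler once per polled token, moving the boxed GenerationContext into each call. The following is a rough sketch of the matching call site on the Rust side, assuming `backend` is the UniquePtr<TensorRtLlmBackendImpl> returned by create_tensorrt_llm_backend, `on_token` is the handler sketched after the ffi.h hunk, and the bridge items are visible from the calling module.

// Sketch only: drives the C++ Stream loop through the cxx bridge declared in lib.rs.
// `request_id` is assumed to come from a prior Submit call.
fn stream_tokens(
    backend: &mut cxx::UniquePtr<ffi::TensorRtLlmBackendImpl>,
    ctx: Box<GenerationContext>,
    request_id: u64,
) -> u32 {
    // Blocks until the executor reports isFinal; every token is delivered via on_token.
    backend.pin_mut().stream(ctx, request_id, on_token)
}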
|
--- a/backends/trtllm/src/lib.rs
+++ b/backends/trtllm/src/lib.rs
@@ -1,10 +1,16 @@
 pub use backend::TrtLLmBackend;
 
+use crate::backend::GenerationContext;
+
 mod backend;
 pub mod errors;
 
 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
+    extern "Rust" {
+        type GenerationContext;
+    }
+
     unsafe extern "C++" {
         include!("backends/trtllm/src/ffi.cpp");
 
@@ -25,7 +31,8 @@ mod ffi {
         /// ```
         ///
         /// ```
-        fn create_trtllm_backend(
+        #[rust_name = "create_tensorrt_llm_backend"]
+        fn CreateTensorRtLlmBackend(
             engine_folder: &str,
             executor_worker: &str,
         ) -> UniquePtr<TensorRtLlmBackendImpl>;
@@ -44,12 +51,12 @@ mod ffi {
             seed: u64,
         ) -> u64;
 
-        // #[rust_name = "stream"]
-        // fn Stream(
-        //     self: Pin<&mut TensorRtLlmBackendImpl>,
-        //     request_id: u64,
-        //     ctx: Box<GenerationContext>,
-        //     callback: fn(u32, u32, bool),
-        // ) -> u32;
+        #[rust_name = "stream"]
+        fn Stream(
+            self: Pin<&mut TensorRtLlmBackendImpl>,
+            ctx: Box<GenerationContext>,
+            request_id: u64,
+            callback: fn(Box<GenerationContext>, u32, u32, bool),
+        ) -> u32;
     }
 }
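Putting the pieces together, the end-to-end flow this commit enables is: create the backend through the bridge, submit a tokenized request, then stream the generated tokens back through the callback. Below is a hypothetical driver with illustrative paths and sampling values; it reuses the on_token handler sketched earlier and assumes the Submit binding is exposed to Rust as `submit` with the argument order shown in ffi.h (the attribute that would rename it, if any, is outside the hunks above).

// End-to-end sketch under the assumptions stated above; none of the literal values
// below come from this commit.
fn run_generation(ctx: Box<GenerationContext>) -> u32 {
    let mut backend = ffi::create_tensorrt_llm_backend(
        "/repository/engines",               // hypothetical engine folder
        "/opt/tensorrt_llm/executorWorker",  // hypothetical executor worker path
    );

    let tokens: Vec<u32> = vec![1, 15043, 3186]; // hypothetical prompt token ids
    let request_id = backend
        .pin_mut()
        .submit(&tokens, 128, 50, 0.95, 1.0, 2014); // max_new_tokens, top_k, top_p, temperature, seed

    // Poll-and-callback loop on the C++ side; returns the number of generated tokens.
    backend.pin_mut().stream(ctx, request_id, on_token)
}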