Initial setup for CXX binding to TRTLLM
This commit is contained in:
parent
230f2a415a
commit
dc402dc9ac
|
@ -565,6 +565,25 @@ version = "0.7.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce"
|
||||
|
||||
[[package]]
|
||||
name = "cmake"
|
||||
version = "0.1.50"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a31c789563b815f77f4250caee12365734369f942439b7defd71e18a48197130"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codespan-reporting"
|
||||
version = "0.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e"
|
||||
dependencies = [
|
||||
"termcolor",
|
||||
"unicode-width",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "color_quant"
|
||||
version = "1.1.0"
|
||||
|
@ -684,6 +703,50 @@ dependencies = [
|
|||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cxx"
|
||||
version = "1.0.124"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "273dcfd3acd4e1e276af13ed2a43eea7001318823e7a726a6b3ed39b4acc0b82"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"cxxbridge-flags",
|
||||
"cxxbridge-macro",
|
||||
"link-cplusplus",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cxx-build"
|
||||
version = "1.0.124"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d8b2766fbd92be34e9ed143898fce6c572dc009de39506ed6903e5a05b68914e"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"codespan-reporting",
|
||||
"once_cell",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"scratch",
|
||||
"syn 2.0.66",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cxxbridge-flags"
|
||||
version = "1.0.124"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "839fcd5e43464614ffaa989eaf1c139ef1f0c51672a1ed08023307fa1b909ccd"
|
||||
|
||||
[[package]]
|
||||
name = "cxxbridge-macro"
|
||||
version = "1.0.124"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4b2c1c1776b986979be68bb2285da855f8d8a35851a769fca8740df7c3d07877"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.66",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.20.9"
|
||||
|
@ -1615,6 +1678,15 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "link-cplusplus"
|
||||
version = "1.0.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d240c6f7e1ba3a28b0249f774e6a9dd0175054b52dfbb61b16eb8505c3785c9"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.4.14"
|
||||
|
@ -3040,6 +3112,12 @@ version = "1.2.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||
|
||||
[[package]]
|
||||
name = "scratch"
|
||||
version = "1.0.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152"
|
||||
|
||||
[[package]]
|
||||
name = "sct"
|
||||
version = "0.7.1"
|
||||
|
@ -3367,6 +3445,29 @@ dependencies = [
|
|||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
version = "1.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755"
|
||||
dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-backends-trtllm"
|
||||
version = "2.0.5-dev0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
"cmake",
|
||||
"cxx",
|
||||
"cxx-build",
|
||||
"text-generation-router",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "text-generation-launcher"
|
||||
version = "2.0.5-dev0"
|
||||
|
|
|
@ -5,7 +5,7 @@ members = [
|
|||
# "backends/client",
|
||||
"backends/grpc-metadata",
|
||||
"launcher"
|
||||
]
|
||||
, "backends/trtllm"]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
cmake_minimum_required(VERSION 3.26)
|
||||
|
||||
project(tgi-trtllm-backend VERSION 1.0.0)
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
|
||||
|
||||
|
||||
#### External dependencies ####
|
||||
# Logging library (SPDLOG)
|
||||
set(SPDLOG_USE_FMT ON)
|
||||
fetchcontent_declare(
|
||||
spdlog
|
||||
GIT_REPOSITORY https://github.com/gabime/spdlog.git
|
||||
GIT_TAG v2.x
|
||||
)
|
||||
fetchcontent_makeavailable(spdlog)
|
||||
|
||||
# TensorRT-LLM
|
||||
fetchcontent_declare(
|
||||
trtllm
|
||||
GIT_REPOSITORY https://github.com/nvidia/tensorrt-llm.git
|
||||
GIT_TAG 9691e12bce7ae1c126c435a049eb516eb119486c
|
||||
)
|
||||
fetchcontent_populate(trtllm)
|
||||
include_directories("${trtllm_SOURCE_DIR}/cpp/include")
|
||||
message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")
|
||||
|
||||
# TGI TRTLLM Backend definition
|
||||
add_library(tgi_trtllm_backend_impl include/backend.h lib/backend.cpp)
|
||||
|
||||
target_include_directories(tgi_trtllm_backend_impl PRIVATE
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include>
|
||||
)
|
||||
target_link_libraries(tgi_trtllm_backend_impl PRIVATE spdlog)
|
||||
|
||||
#### Unit Tests ####
|
||||
if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
|
||||
message(STATUS "Building tests")
|
||||
FetchContent_Declare(
|
||||
Catch2
|
||||
GIT_REPOSITORY https://github.com/catchorg/Catch2
|
||||
GIT_TAG v3.6.0
|
||||
)
|
||||
FetchContent_MakeAvailable(Catch2)
|
||||
|
||||
add_executable(tgi_trtllm_backend_tests)
|
||||
target_link_libraries(tests PRIVATE Catch2::Catch2::Catch2WithMain)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
|
||||
include(CTest)
|
||||
include(Catch)
|
||||
catch_discover_tests(tests)
|
||||
endif ()
|
|
@ -0,0 +1,18 @@
|
|||
[package]
|
||||
name = "text-generation-backends-trtllm"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.74"
|
||||
async-stream = "0.3.5"
|
||||
cxx = "1.0"
|
||||
text-generation-router = { path = "../../router" }
|
||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tokio-stream = "0.1.14"
|
||||
|
||||
[build-dependencies]
|
||||
cmake = "0.1"
|
||||
cxx-build = "1.0"
|
|
@ -0,0 +1,19 @@
|
|||
use cxx_build::CFG;
|
||||
|
||||
fn main() {
|
||||
let backend_path = cmake::Config::new(".")
|
||||
.uses_cxx11()
|
||||
.generator("Ninja")
|
||||
.build_target("tgi_trtllm_backend_impl")
|
||||
.build();
|
||||
|
||||
CFG.include_prefix = "backends/trtllm";
|
||||
cxx_build::bridge("src/lib.rs")
|
||||
.file("src/ffi.cpp")
|
||||
.std("c++20")
|
||||
.compile("tgi_trtllm_backend");
|
||||
|
||||
println!("cargo:rerun-if-changed=include/backend.h");
|
||||
println!("cargo:rerun-if-changed=lib/backend.cpp");
|
||||
// println!("cargo:rustc-link-lib=tgi_trtllm_backend_impl");
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
//
|
||||
// Created by Morgan Funtowicz on 6/30/24.
|
||||
//
|
||||
|
||||
#ifndef TGI_TRTLLM_BACKEND_H
|
||||
#define TGI_TRTLLM_BACKEND_H
|
||||
|
||||
#include <filesystem>
|
||||
|
||||
//#include <tensorrt_llm/runtime/common.h>
|
||||
//#include <tensorrt_llm/executor/executor.h>
|
||||
//
|
||||
//namespace tle = tensorrt_llm::executor;
|
||||
|
||||
namespace huggingface::tgi::backends {
|
||||
class TensorRtLlmBackendImpl {
|
||||
private:
|
||||
// tle::Executor executor;
|
||||
|
||||
public:
|
||||
TensorRtLlmBackendImpl(std::filesystem::path &engineFolder);
|
||||
};
|
||||
|
||||
/***
|
||||
*
|
||||
* @param engineFolder
|
||||
* @return
|
||||
*/
|
||||
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
|
||||
create_trtllm_backend(std::filesystem::path &engineFolder);
|
||||
}
|
||||
|
||||
#endif //TGI_TRTLLM_BACKEND_H
|
|
@ -0,0 +1,13 @@
|
|||
#include <spdlog/spdlog.h>
|
||||
#include <fmt/std.h>
|
||||
|
||||
#include "backend.h"
|
||||
|
||||
huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(std::filesystem::path &engineFolder) {
|
||||
SPDLOG_INFO(FMT_STRING("Loading engines from {}"), engineFolder);
|
||||
}
|
||||
|
||||
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
|
||||
huggingface::tgi::backends::create_trtllm_backend(std::filesystem::path &engineFolder) {
|
||||
return std::make_unique<huggingface::tgi::backends::TensorRtLlmBackendImpl>(engineFolder);
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
|
||||
use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::ValidGenerateRequest;
|
||||
|
||||
pub struct TensorRtLLmBackend {}
|
||||
|
||||
impl Backend for TensorRtLLmBackend {
|
||||
fn schedule(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn health(&self, current_health: bool) -> bool {
|
||||
todo!()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
//
|
||||
// Created by mfuntowicz on 6/30/24.
|
||||
//
|
||||
#include <filesystem>
|
||||
|
||||
|
||||
namespace huggingface::tgi::backends::trtllm {
|
||||
class TensorRtLlmBackend {
|
||||
public:
|
||||
TensorRtLlmBackend(std::filesystem::path engineFolder) {
|
||||
|
||||
}
|
||||
};
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
mod backend;
|
||||
|
||||
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
|
||||
mod ffi {
|
||||
unsafe extern "C++" {
|
||||
include!("backends/trtllm/include/backend.h");
|
||||
|
||||
type TensorRtLlmBackendImpl;
|
||||
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue