Initial setup for CXX binding to TRTLLM

This commit is contained in:
Morgan Funtowicz 2024-06-30 23:37:20 +02:00
parent 230f2a415a
commit dc402dc9ac
10 changed files with 286 additions and 1 deletions

101
Cargo.lock generated
View File

@ -565,6 +565,25 @@ version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce"
[[package]]
name = "cmake"
version = "0.1.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a31c789563b815f77f4250caee12365734369f942439b7defd71e18a48197130"
dependencies = [
"cc",
]
[[package]]
name = "codespan-reporting"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e"
dependencies = [
"termcolor",
"unicode-width",
]
[[package]]
name = "color_quant"
version = "1.1.0"
@ -684,6 +703,50 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "cxx"
version = "1.0.124"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "273dcfd3acd4e1e276af13ed2a43eea7001318823e7a726a6b3ed39b4acc0b82"
dependencies = [
"cc",
"cxxbridge-flags",
"cxxbridge-macro",
"link-cplusplus",
]
[[package]]
name = "cxx-build"
version = "1.0.124"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8b2766fbd92be34e9ed143898fce6c572dc009de39506ed6903e5a05b68914e"
dependencies = [
"cc",
"codespan-reporting",
"once_cell",
"proc-macro2",
"quote",
"scratch",
"syn 2.0.66",
]
[[package]]
name = "cxxbridge-flags"
version = "1.0.124"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "839fcd5e43464614ffaa989eaf1c139ef1f0c51672a1ed08023307fa1b909ccd"
[[package]]
name = "cxxbridge-macro"
version = "1.0.124"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b2c1c1776b986979be68bb2285da855f8d8a35851a769fca8740df7c3d07877"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
]
[[package]]
name = "darling"
version = "0.20.9"
@ -1615,6 +1678,15 @@ dependencies = [
"libc",
]
[[package]]
name = "link-cplusplus"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d240c6f7e1ba3a28b0249f774e6a9dd0175054b52dfbb61b16eb8505c3785c9"
dependencies = [
"cc",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.14"
@ -3040,6 +3112,12 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "scratch"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152"
[[package]]
name = "sct"
version = "0.7.1"
@ -3367,6 +3445,29 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "termcolor"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755"
dependencies = [
"winapi-util",
]
[[package]]
name = "text-generation-backends-trtllm"
version = "2.0.5-dev0"
dependencies = [
"async-stream",
"async-trait",
"cmake",
"cxx",
"cxx-build",
"text-generation-router",
"tokio",
"tokio-stream",
]
[[package]]
name = "text-generation-launcher"
version = "2.0.5-dev0"

View File

@ -5,7 +5,7 @@ members = [
# "backends/client",
"backends/grpc-metadata",
"launcher"
]
, "backends/trtllm"]
resolver = "2"
[workspace.package]

View File

@ -0,0 +1,57 @@
cmake_minimum_required(VERSION 3.26)
project(tgi-trtllm-backend VERSION 1.0.0)
set(CMAKE_CXX_STANDARD 20)
include(FetchContent)
option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
#### External dependencies ####
# Logging library (SPDLOG)
set(SPDLOG_USE_FMT ON)
fetchcontent_declare(
spdlog
GIT_REPOSITORY https://github.com/gabime/spdlog.git
GIT_TAG v2.x
)
fetchcontent_makeavailable(spdlog)
# TensorRT-LLM
fetchcontent_declare(
trtllm
GIT_REPOSITORY https://github.com/nvidia/tensorrt-llm.git
GIT_TAG 9691e12bce7ae1c126c435a049eb516eb119486c
)
fetchcontent_populate(trtllm)
include_directories("${trtllm_SOURCE_DIR}/cpp/include")
message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")
# TGI TRTLLM Backend definition
add_library(tgi_trtllm_backend_impl include/backend.h lib/backend.cpp)
target_include_directories(tgi_trtllm_backend_impl PRIVATE
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include>
)
target_link_libraries(tgi_trtllm_backend_impl PRIVATE spdlog)
#### Unit Tests ####
if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
message(STATUS "Building tests")
FetchContent_Declare(
Catch2
GIT_REPOSITORY https://github.com/catchorg/Catch2
GIT_TAG v3.6.0
)
FetchContent_MakeAvailable(Catch2)
add_executable(tgi_trtllm_backend_tests)
target_link_libraries(tests PRIVATE Catch2::Catch2::Catch2WithMain)
list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
include(CTest)
include(Catch)
catch_discover_tests(tests)
endif ()

View File

@ -0,0 +1,18 @@
[package]
name = "text-generation-backends-trtllm"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
[dependencies]
async-trait = "0.1.74"
async-stream = "0.3.5"
cxx = "1.0"
text-generation-router = { path = "../../router" }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.14"
[build-dependencies]
cmake = "0.1"
cxx-build = "1.0"

19
backends/trtllm/build.rs Normal file
View File

@ -0,0 +1,19 @@
use cxx_build::CFG;
fn main() {
let backend_path = cmake::Config::new(".")
.uses_cxx11()
.generator("Ninja")
.build_target("tgi_trtllm_backend_impl")
.build();
CFG.include_prefix = "backends/trtllm";
cxx_build::bridge("src/lib.rs")
.file("src/ffi.cpp")
.std("c++20")
.compile("tgi_trtllm_backend");
println!("cargo:rerun-if-changed=include/backend.h");
println!("cargo:rerun-if-changed=lib/backend.cpp");
// println!("cargo:rustc-link-lib=tgi_trtllm_backend_impl");
}

View File

@ -0,0 +1,33 @@
//
// Created by Morgan Funtowicz on 6/30/24.
//
#ifndef TGI_TRTLLM_BACKEND_H
#define TGI_TRTLLM_BACKEND_H
#include <filesystem>
//#include <tensorrt_llm/runtime/common.h>
//#include <tensorrt_llm/executor/executor.h>
//
//namespace tle = tensorrt_llm::executor;
namespace huggingface::tgi::backends {
class TensorRtLlmBackendImpl {
private:
// tle::Executor executor;
public:
TensorRtLlmBackendImpl(std::filesystem::path &engineFolder);
};
/***
*
* @param engineFolder
* @return
*/
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
create_trtllm_backend(std::filesystem::path &engineFolder);
}
#endif //TGI_TRTLLM_BACKEND_H

View File

@ -0,0 +1,13 @@
#include <spdlog/spdlog.h>
#include <fmt/std.h>
#include "backend.h"
huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(std::filesystem::path &engineFolder) {
SPDLOG_INFO(FMT_STRING("Loading engines from {}"), engineFolder);
}
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
huggingface::tgi::backends::create_trtllm_backend(std::filesystem::path &engineFolder) {
return std::make_unique<huggingface::tgi::backends::TensorRtLlmBackendImpl>(engineFolder);
}

View File

@ -0,0 +1,19 @@
use tokio_stream::wrappers::UnboundedReceiverStream;
use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
pub struct TensorRtLLmBackend {}
impl Backend for TensorRtLLmBackend {
fn schedule(
&self,
request: ValidGenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
todo!()
}
async fn health(&self, current_health: bool) -> bool {
todo!()
}
}

View File

@ -0,0 +1,14 @@
//
// Created by mfuntowicz on 6/30/24.
//
#include <filesystem>
namespace huggingface::tgi::backends::trtllm {
class TensorRtLlmBackend {
public:
TensorRtLlmBackend(std::filesystem::path engineFolder) {
}
};
}

View File

@ -0,0 +1,11 @@
mod backend;
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
mod ffi {
unsafe extern "C++" {
include!("backends/trtllm/include/backend.h");
type TensorRtLlmBackendImpl;
}
}