From dc402dc9acef0c3747c85d8cc38bdaa6180651dd Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Sun, 30 Jun 2024 23:37:20 +0200 Subject: [PATCH] Initial setup for CXX binding to TRTLLM --- Cargo.lock | 101 ++++++++++++++++++++++++++++++ Cargo.toml | 2 +- backends/trtllm/CMakeLists.txt | 57 +++++++++++++++++ backends/trtllm/Cargo.toml | 18 ++++++ backends/trtllm/build.rs | 19 ++++++ backends/trtllm/include/backend.h | 33 ++++++++++ backends/trtllm/lib/backend.cpp | 13 ++++ backends/trtllm/src/backend.rs | 19 ++++++ backends/trtllm/src/ffi.cpp | 14 +++++ backends/trtllm/src/lib.rs | 11 ++++ 10 files changed, 286 insertions(+), 1 deletion(-) create mode 100644 backends/trtllm/CMakeLists.txt create mode 100644 backends/trtllm/Cargo.toml create mode 100644 backends/trtllm/build.rs create mode 100644 backends/trtllm/include/backend.h create mode 100644 backends/trtllm/lib/backend.cpp create mode 100644 backends/trtllm/src/backend.rs create mode 100644 backends/trtllm/src/ffi.cpp create mode 100644 backends/trtllm/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index ce911ce7..27404e41 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -565,6 +565,25 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" +[[package]] +name = "cmake" +version = "0.1.50" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a31c789563b815f77f4250caee12365734369f942439b7defd71e18a48197130" +dependencies = [ + "cc", +] + +[[package]] +name = "codespan-reporting" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e" +dependencies = [ + "termcolor", + "unicode-width", +] + [[package]] name = "color_quant" version = "1.1.0" @@ -684,6 +703,50 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "cxx" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "273dcfd3acd4e1e276af13ed2a43eea7001318823e7a726a6b3ed39b4acc0b82" +dependencies = [ + "cc", + "cxxbridge-flags", + "cxxbridge-macro", + "link-cplusplus", +] + +[[package]] +name = "cxx-build" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8b2766fbd92be34e9ed143898fce6c572dc009de39506ed6903e5a05b68914e" +dependencies = [ + "cc", + "codespan-reporting", + "once_cell", + "proc-macro2", + "quote", + "scratch", + "syn 2.0.66", +] + +[[package]] +name = "cxxbridge-flags" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "839fcd5e43464614ffaa989eaf1c139ef1f0c51672a1ed08023307fa1b909ccd" + +[[package]] +name = "cxxbridge-macro" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b2c1c1776b986979be68bb2285da855f8d8a35851a769fca8740df7c3d07877" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + [[package]] name = "darling" version = "0.20.9" @@ -1615,6 +1678,15 @@ dependencies = [ "libc", ] +[[package]] +name = "link-cplusplus" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d240c6f7e1ba3a28b0249f774e6a9dd0175054b52dfbb61b16eb8505c3785c9" +dependencies = [ + "cc", +] + [[package]] name = "linux-raw-sys" version = "0.4.14" @@ -3040,6 +3112,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scratch" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152" + [[package]] name = "sct" version = "0.7.1" @@ -3367,6 +3445,29 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "text-generation-backends-trtllm" +version = "2.0.5-dev0" +dependencies = [ + "async-stream", + "async-trait", + "cmake", + "cxx", + "cxx-build", + "text-generation-router", + "tokio", + "tokio-stream", +] + [[package]] name = "text-generation-launcher" version = "2.0.5-dev0" diff --git a/Cargo.toml b/Cargo.toml index 28ded514..e91f2609 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ members = [ # "backends/client", "backends/grpc-metadata", "launcher" -] +, "backends/trtllm"] resolver = "2" [workspace.package] diff --git a/backends/trtllm/CMakeLists.txt b/backends/trtllm/CMakeLists.txt new file mode 100644 index 00000000..ef8700d9 --- /dev/null +++ b/backends/trtllm/CMakeLists.txt @@ -0,0 +1,57 @@ +cmake_minimum_required(VERSION 3.26) + +project(tgi-trtllm-backend VERSION 1.0.0) +set(CMAKE_CXX_STANDARD 20) + +include(FetchContent) + +option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF) + + +#### External dependencies #### +# Logging library (SPDLOG) +set(SPDLOG_USE_FMT ON) +fetchcontent_declare( + spdlog + GIT_REPOSITORY https://github.com/gabime/spdlog.git + GIT_TAG v2.x +) +fetchcontent_makeavailable(spdlog) + +# TensorRT-LLM +fetchcontent_declare( + trtllm + GIT_REPOSITORY https://github.com/nvidia/tensorrt-llm.git + GIT_TAG 9691e12bce7ae1c126c435a049eb516eb119486c +) +fetchcontent_populate(trtllm) +include_directories("${trtllm_SOURCE_DIR}/cpp/include") +message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}") + +# TGI TRTLLM Backend definition +add_library(tgi_trtllm_backend_impl include/backend.h lib/backend.cpp) + +target_include_directories(tgi_trtllm_backend_impl PRIVATE + $ + $ +) +target_link_libraries(tgi_trtllm_backend_impl PRIVATE spdlog) + +#### Unit Tests #### +if (${TGI_TRTLLM_BACKEND_BUILD_TESTS}) + message(STATUS "Building tests") + FetchContent_Declare( + Catch2 + GIT_REPOSITORY https://github.com/catchorg/Catch2 + GIT_TAG v3.6.0 + ) + FetchContent_MakeAvailable(Catch2) + + add_executable(tgi_trtllm_backend_tests) + target_link_libraries(tests PRIVATE Catch2::Catch2::Catch2WithMain) + + list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras) + include(CTest) + include(Catch) + catch_discover_tests(tests) +endif () \ No newline at end of file diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml new file mode 100644 index 00000000..39369c48 --- /dev/null +++ b/backends/trtllm/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "text-generation-backends-trtllm" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[dependencies] +async-trait = "0.1.74" +async-stream = "0.3.5" +cxx = "1.0" +text-generation-router = { path = "../../router" } +tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokio-stream = "0.1.14" + +[build-dependencies] +cmake = "0.1" +cxx-build = "1.0" \ No newline at end of file diff --git a/backends/trtllm/build.rs b/backends/trtllm/build.rs new file mode 100644 index 00000000..32d0f943 --- /dev/null +++ b/backends/trtllm/build.rs @@ -0,0 +1,19 @@ +use cxx_build::CFG; + +fn main() { + let backend_path = cmake::Config::new(".") + .uses_cxx11() + .generator("Ninja") + .build_target("tgi_trtllm_backend_impl") + .build(); + + CFG.include_prefix = "backends/trtllm"; + cxx_build::bridge("src/lib.rs") + .file("src/ffi.cpp") + .std("c++20") + .compile("tgi_trtllm_backend"); + + println!("cargo:rerun-if-changed=include/backend.h"); + println!("cargo:rerun-if-changed=lib/backend.cpp"); + // println!("cargo:rustc-link-lib=tgi_trtllm_backend_impl"); +} diff --git a/backends/trtllm/include/backend.h b/backends/trtllm/include/backend.h new file mode 100644 index 00000000..d2b7f853 --- /dev/null +++ b/backends/trtllm/include/backend.h @@ -0,0 +1,33 @@ +// +// Created by Morgan Funtowicz on 6/30/24. +// + +#ifndef TGI_TRTLLM_BACKEND_H +#define TGI_TRTLLM_BACKEND_H + +#include + +//#include +//#include +// +//namespace tle = tensorrt_llm::executor; + +namespace huggingface::tgi::backends { + class TensorRtLlmBackendImpl { + private: +// tle::Executor executor; + + public: + TensorRtLlmBackendImpl(std::filesystem::path &engineFolder); + }; + + /*** + * + * @param engineFolder + * @return + */ + std::unique_ptr + create_trtllm_backend(std::filesystem::path &engineFolder); +} + +#endif //TGI_TRTLLM_BACKEND_H diff --git a/backends/trtllm/lib/backend.cpp b/backends/trtllm/lib/backend.cpp new file mode 100644 index 00000000..ec447fbf --- /dev/null +++ b/backends/trtllm/lib/backend.cpp @@ -0,0 +1,13 @@ +#include +#include + +#include "backend.h" + +huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(std::filesystem::path &engineFolder) { + SPDLOG_INFO(FMT_STRING("Loading engines from {}"), engineFolder); +} + +std::unique_ptr +huggingface::tgi::backends::create_trtllm_backend(std::filesystem::path &engineFolder) { + return std::make_unique(engineFolder); +} diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs new file mode 100644 index 00000000..3b3b5bce --- /dev/null +++ b/backends/trtllm/src/backend.rs @@ -0,0 +1,19 @@ +use tokio_stream::wrappers::UnboundedReceiverStream; + +use text_generation_router::infer::{Backend, InferError, InferStreamResponse}; +use text_generation_router::validation::ValidGenerateRequest; + +pub struct TensorRtLLmBackend {} + +impl Backend for TensorRtLLmBackend { + fn schedule( + &self, + request: ValidGenerateRequest, + ) -> Result>, InferError> { + todo!() + } + + async fn health(&self, current_health: bool) -> bool { + todo!() + } +} diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp new file mode 100644 index 00000000..70f23394 --- /dev/null +++ b/backends/trtllm/src/ffi.cpp @@ -0,0 +1,14 @@ +// +// Created by mfuntowicz on 6/30/24. +// +#include + + +namespace huggingface::tgi::backends::trtllm { + class TensorRtLlmBackend { + public: + TensorRtLlmBackend(std::filesystem::path engineFolder) { + + } + }; +} \ No newline at end of file diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs new file mode 100644 index 00000000..0ac550cb --- /dev/null +++ b/backends/trtllm/src/lib.rs @@ -0,0 +1,11 @@ +mod backend; + +#[cxx::bridge(namespace = "huggingface::tgi::backends")] +mod ffi { + unsafe extern "C++" { + include!("backends/trtllm/include/backend.h"); + + type TensorRtLlmBackendImpl; + + } +}