From 3a2698fb79e179a75cff589178379adec20ffe5e Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Tue, 19 Nov 2024 00:17:35 +0100
Subject: [PATCH] feat(backend): initial rewrite of the backend for simplicity

---
 backends/trtllm/csrc/backend.cpp |  38 ++++++++++++
 backends/trtllm/csrc/backend.hpp | 100 +++++++++++++++++++++++++++++++
 backends/trtllm/csrc/ffi.hpp     |  60 +++++++++++++++++++
 3 files changed, 198 insertions(+)
 create mode 100644 backends/trtllm/csrc/backend.cpp
 create mode 100644 backends/trtllm/csrc/backend.hpp
 create mode 100644 backends/trtllm/csrc/ffi.hpp

diff --git a/backends/trtllm/csrc/backend.cpp b/backends/trtllm/csrc/backend.cpp
new file mode 100644
index 00000000..2c681dd1
--- /dev/null
+++ b/backends/trtllm/csrc/backend.cpp
@@ -0,0 +1,38 @@
+#include <fmt/format.h>
+#include <spdlog/spdlog.h>
+#include "backend.hpp"
+
+#include <tensorrt_llm/executor/executor.h>
+
+namespace huggingface::tgi::backends::trtllm {
+
+    size_t backend_t::num_tokens_ready() const noexcept {
+        return executor_.getNumResponsesReady();
+    }
+
+    std::expected<request_id_t, backend_exception_t>
+    backend_t::submit(std::span<const token_id_t> token_ids, generation_params_t generation_params, sampling_params_t sampling_params) noexcept {
+        SPDLOG_DEBUG(FMT_STRING("Submitting {:d} tokens to the executor for scheduling"), token_ids.size());
+        return executor_.enqueueRequest(tle::Request {
+            {token_ids.begin(), token_ids.end()},  // Making actual copy of the tokens
+            static_cast<tle::SizeType32>(generation_params.max_new_tokens),
+            true,                                  // Stream tokens back as they are generated
+            (tle::SamplingConfig) sampling_params,
+            tle::OutputConfig { /* returnLogProbs= */ true },
+            std::nullopt,                          // endId
+            std::nullopt,                          // padId
+            std::nullopt,                          // positionIds
+            std::nullopt,                          // badWords
+            stop_words_
+        });
+    }
+
+    std::vector<tle::Response> backend_t::pull_tokens() noexcept {
+        return executor_.awaitResponses();
+    }
+
+    void backend_t::cancel(request_id_t request_id) noexcept {
+        SPDLOG_INFO(FMT_STRING("Cancelling request: {:d}"), request_id);
+        executor_.cancelRequest(request_id);
+    }
+}
\ No newline at end of file
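
A minimal sketch (not part of the patch) of how backend_t is intended to be driven: submit a prompt, then poll the executor until the final response for that request arrives. The engine folder path, token ids, and sampling values are placeholders, and the backend_t constructor taking the engine folder is an assumption; the response accessors (getRequestId, hasError, getResult().isFinal) are the TensorRT-LLM executor API.

    #include <chrono>
    #include <filesystem>
    #include <thread>
    #include "backend.hpp"

    using namespace huggingface::tgi::backends::trtllm;

    int main() {
        std::filesystem::path engines{"/repository/engines"};  // hypothetical path
        backend_t backend{engines};                            // assumed constructor

        const std::vector<token_id_t> prompt{1, 15043, 3186};  // placeholder token ids
        const auto request_id = backend.submit(
            prompt,
            {/* max_new_tokens = */ 128},
            {/* top_k = */ 50, /* top_p = */ 0.95f, /* repetition_penalty = */ 1.0f,
             /* frequency_penalty = */ 0.0f, /* length_penalty = */ 0.0f,
             /* temperature = */ 0.7f, /* seed = */ 42});
        if (!request_id.has_value()) return 1;

        for (auto done = false; !done;) {
            if (backend.num_tokens_ready() == 0) {
                std::this_thread::sleep_for(std::chrono::microseconds(100));
                continue;
            }
            // Drain every response ready on the executor; isFinal marks the last token
            for (const auto &response : backend.pull_tokens()) {
                if (response.getRequestId() != *request_id || response.hasError()) continue;
                done = response.getResult().isFinal;
            }
        }
        return 0;
    }
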
diff --git a/backends/trtllm/csrc/backend.hpp b/backends/trtllm/csrc/backend.hpp
new file mode 100644
index 00000000..9627f0ec
--- /dev/null
+++ b/backends/trtllm/csrc/backend.hpp
@@ -0,0 +1,100 @@
+#include <cmath>
+#include <cstdint>
+#include <exception>
+#include <expected>
+#include <filesystem>
+#include <list>
+#include <span>
+#include <vector>
+
+#include <tensorrt_llm/executor/executor.h>
+
+namespace huggingface::tgi::backends::trtllm {
+    namespace tle = tensorrt_llm::executor;
+
+    using request_id_t = uint64_t;  // Matches tle::IdType returned by the executor
+    using token_id_t = tle::TokenIdType;
+
+    /**
+     * Represent the parameters used for generation
+     */
+    struct generation_params_t {
+        uint32_t max_new_tokens;
+    };
+
+    /**
+     * Represent the parameters used to sample tokens from the logit distribution
+     */
+    struct sampling_params_t {
+        uint32_t top_k;
+        float_t top_p;
+        float_t repetition_penalty;
+        float_t frequency_penalty;
+        float_t length_penalty;
+        float_t temperature;
+        uint64_t seed;
+
+        explicit operator tle::SamplingConfig() const {
+            return tle::SamplingConfig {
+                1,             // beamWidth
+                top_k,
+                top_p,
+                std::nullopt,  // topPMin
+                std::nullopt,  // topPResetIds
+                std::nullopt,  // topPDecay
+                seed,
+                temperature,
+                std::nullopt,  // minLength
+                std::nullopt,  // beamSearchDiversityRate
+                repetition_penalty,
+                std::nullopt,  // presencePenalty
+                frequency_penalty,
+                length_penalty
+            };
+        }
+    };
+
+    /**
+     * Exception raised when the executor fails to schedule or run a request
+     */
+    class backend_exception_t: public std::exception {};
+
+    /**
+     * Thin wrapper around the TensorRT-LLM executor, tracking in-flight requests
+     */
+    class backend_t {
+    private:
+        tle::Executor executor_;
+        std::list<std::vector<token_id_t>> stop_words_;
+
+    public:
+        explicit backend_t(std::filesystem::path &engine_folder);
+
+        /**
+         * Submit a new request to the executor
+         * @param token_ids Prompt token ids to schedule for generation
+         * @param generation_params
+         * @param sampling_params
+         * @return Either the newly submitted request's id or the error explaining why it failed to submit
+         */
+        [[nodiscard("Discarded executor request_id needs to be assigned")]]
+        std::expected<request_id_t, backend_exception_t>
+        submit(std::span<const token_id_t> token_ids, generation_params_t generation_params, sampling_params_t sampling_params) noexcept;
+
+        /**
+         * Query the number of tokens available across all in-flight generations
+         * @return Number of generated tokens ready to be pulled
+         */
+        [[nodiscard("Pulling out the number of tokens")]]
+        size_t num_tokens_ready() const noexcept;
+
+        /**
+         * Pull out newly generated tokens from the executor
+         * @return Responses produced since the last call, one per generated token
+         */
+        [[nodiscard("Generated responses will be lost if not consumed")]]
+        std::vector<tle::Response> pull_tokens() noexcept;
+
+        /**
+         * Cancel the specified request from the executor's in-flight set
+         * @param request_id Identifier of the request to remove from the in-flight executor
+         */
+        void cancel(request_id_t) noexcept;
+    };
+}
\ No newline at end of file
diff --git a/backends/trtllm/csrc/ffi.hpp b/backends/trtllm/csrc/ffi.hpp
new file mode 100644
index 00000000..d72b26db
--- /dev/null
+++ b/backends/trtllm/csrc/ffi.hpp
@@ -0,0 +1,60 @@
+#include <spdlog/spdlog.h>
+#include <tensorrt_llm/common/tllmException.h>
+
+#include "backend.hpp"
+
+namespace rust::behavior {
+    template <typename Try, typename Fail>
+    static void trycatch(Try &&func, Fail &&fail) noexcept try {
+        func();
+    } catch (tensorrt_llm::common::TllmException &e) {
+        fail(e.what());
+    } catch (std::exception &e) {
+        // Fallback: surface any other C++ exception to Rust instead of std::terminate
+        fail(e.what());
+    }
+}
+
+#include <backends/trtllm/src/lib.rs.h>
+
+namespace huggingface::tgi::backends::trtllm {
+
+    class tensorrt_llm_backend_t {
+    private:
+        backend_t inner_;
+
+    public:
+        tensorrt_llm_backend_t(std::filesystem::path &engine_folder): inner_(engine_folder) {}
+
+        size_t num_tokens_ready() const noexcept {
+            return inner_.num_tokens_ready();
+        }
+
+        request_id_t submit(
+            rust::Slice<const uint32_t> tokens,
+            uint32_t max_new_tokens,
+            uint32_t top_k,
+            float_t top_p,
+            float_t temperature,
+            float_t repetition_penalty,
+            float_t frequency_penalty,
+            uint64_t seed
+        ) {
+            // Submit the request to the executor and get back a potential request_id used to track request status
+            const auto maybe_request_id = inner_.submit(
+                {reinterpret_cast<const token_id_t *>(tokens.data()), tokens.size()},
+                {max_new_tokens},
+                // length_penalty is not exposed over the FFI yet; 0.0f assumed (no length normalization)
+                {top_k, top_p, repetition_penalty, frequency_penalty, 0.0f, temperature, seed}
+            );
+
+            // If we do have a value, let's return the request_id
+            if(maybe_request_id.has_value()) [[likely]] {
+                return *maybe_request_id;
+            } else {
+                // Raise to the Rust side; rust::behavior::trycatch above turns this into an Err
+                throw maybe_request_id.error();
+            }
+        }
+
+        void cancel(request_id_t requestId) noexcept {
+            SPDLOG_DEBUG(FMT_STRING("[FFI] Cancelling request {:d}"), requestId);
+            inner_.cancel(requestId);
+        }
+    };
+}
\ No newline at end of file
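
For completeness, a hypothetical C++-side smoke test of the FFI wrapper (it will not link without the cxx-generated lib.rs.h; in production the generated Rust bindings are the caller). The engine path and all numeric values are placeholders; rust::Slice's pointer-plus-length constructor is provided by cxx.

    #include "ffi.hpp"

    int main() {
        using namespace huggingface::tgi::backends::trtllm;

        std::filesystem::path engines{"/repository/engines"};  // hypothetical path
        tensorrt_llm_backend_t backend{engines};

        const uint32_t prompt[] = {1, 15043, 3186};            // placeholder token ids
        const rust::Slice<const uint32_t> tokens{prompt, 3};

        try {
            const auto request_id = backend.submit(
                tokens, /* max_new_tokens = */ 64, /* top_k = */ 50, /* top_p = */ 0.95f,
                /* temperature = */ 0.7f, /* repetition_penalty = */ 1.0f,
                /* frequency_penalty = */ 0.0f, /* seed = */ 42);
            backend.cancel(request_id);  // e.g. the client went away before completion
        } catch (const std::exception &e) {
            // Over the FFI boundary, rust::behavior::trycatch maps this to Result::Err
            return 1;
        }
        return 0;
    }
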