// // Created by mfuntowicz on 7/11/24. // #ifndef TGI_TRTLLM_BACKEND_FFI_H #define TGI_TRTLLM_BACKEND_FFI_H #include <cmath> #include <cstddef> #include <memory> #include "backend.h" namespace huggingface::tgi::backends { class TensorRtLlmBackendImpl; } // Template to support returning error from TllmException back to Rust in a Result<> #include <tensorrt_llm/common/tllmException.h> namespace rust::behavior { template<typename Try, typename Fail> static void trycatch(Try &&func, Fail &&fail) noexcept try { func(); } catch (tensorrt_llm::common::TllmException &e) { fail(e.what()); } } #include "backends/trtllm/src/lib.rs.h" namespace huggingface::tgi::backends { class TensorRtLlmBackendImpl : public TensorRtLlmBackend { public: /*** * * @param engineFolder * @param executorWorker */ TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker); /*** * * @param tokens * @param maxNewTokens * @param topK * @param topP * @param temperature * @param repetition_penalty * @param frequency_penalty * @param seed * @return */ [[nodiscard("returned request id should be used to refer to the request's generation result later on")]] uint64_t Submit(rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty, float_t frequency_penalty, uint64_t seed); /*** * * @return */ std::unique_ptr<std::vector<GenerationStep>> PullTokens(); }; /*** * * @param engineFolder * @return */ std::unique_ptr<TensorRtLlmBackendImpl> CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker); } #endif //TGI_TRTLLM_BACKEND_FFI_H