122 lines
3.1 KiB
C
122 lines
3.1 KiB
C
|
//
|
||
|
// Created by Morgan Funtowicz on 6/30/24.
|
||
|
//
|
||
|
|
||
|
#ifndef TGI_TRTLLM_BACKEND_H
|
||
|
#define TGI_TRTLLM_BACKEND_H
|
||
|
|
||
|
#include <cmath>
|
||
|
#include <filesystem>
|
||
|
#include <span>
|
||
|
#include <vector>
|
||
|
|
||
|
#include <nlohmann/json.hpp>
|
||
|
|
||
|
#include <tensorrt_llm/runtime/common.h>
|
||
|
#include <tensorrt_llm/executor/executor.h>
|
||
|
#include <tensorrt_llm/plugins/api/tllmPlugin.h>
|
||
|
|
||
|
using json = nlohmann::json;
|
||
|
namespace tle = tensorrt_llm::executor;
|
||
|
|
||
|
namespace huggingface::tgi::backends {
|
||
|
using RequestId = tle::IdType;
|
||
|
using TokenId = tle::TokenIdType;
|
||
|
|
||
|
/**
|
||
|
* Initialize all the components required by TRTLLM.
|
||
|
* It is required to call this function before attempting to load any engine
|
||
|
*/
|
||
|
void InitializeBackend();
|
||
|
|
||
|
/**
|
||
|
*
|
||
|
* @param config TensorRT-LLM configuration object
|
||
|
* @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode
|
||
|
* @return
|
||
|
*/
|
||
|
tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
|
||
|
|
||
|
/**
|
||
|
* Get the sampling configuration from the parameters provided by TGI
|
||
|
* @param topK
|
||
|
* @param topP
|
||
|
* @param temperature
|
||
|
* @param repetition_penalty
|
||
|
* @param frequency_penalty
|
||
|
* @param seed
|
||
|
* @return
|
||
|
*/
|
||
|
tle::SamplingConfig GetSamplingConfig(
|
||
|
uint32_t topK,
|
||
|
float_t topP,
|
||
|
float_t temperature,
|
||
|
float_t repetition_penalty,
|
||
|
float_t frequency_penalty,
|
||
|
uint64_t seed
|
||
|
);
|
||
|
|
||
|
/**
|
||
|
*
|
||
|
*/
|
||
|
class TensorRtLlmBackend {
|
||
|
private:
|
||
|
const json config;
|
||
|
tle::Executor executor;
|
||
|
|
||
|
public:
|
||
|
explicit TensorRtLlmBackend(
|
||
|
const std::filesystem::path &engineFolder,
|
||
|
const std::filesystem::path &executorWorker
|
||
|
);
|
||
|
|
||
|
/**
|
||
|
* Indicate if the backend is ready to accept incoming request
|
||
|
* @return true if ready, false otherwise
|
||
|
*/
|
||
|
[[nodiscard]] bool IsReady() const;
|
||
|
|
||
|
/**
|
||
|
* Query the executor for the number of token available for pulling
|
||
|
* @return
|
||
|
*/
|
||
|
[[nodiscard]] size_t NumResponsesReady() const;
|
||
|
|
||
|
/**
|
||
|
* Submit a new generation task to the executor
|
||
|
* @param tokens
|
||
|
* @param topK
|
||
|
* @param topP
|
||
|
* @param temperature
|
||
|
* @param repetition_penalty
|
||
|
* @param frequency_penalty
|
||
|
* @param seed
|
||
|
* @return Request id related to this generation for reference
|
||
|
*/
|
||
|
[[nodiscard]] RequestId Submit(
|
||
|
const std::vector<TokenId> &tokens,
|
||
|
int32_t topK,
|
||
|
float_t topP,
|
||
|
float_t temperature,
|
||
|
float_t repetition_penalty,
|
||
|
float_t frequency_penalty,
|
||
|
uint64_t seed
|
||
|
);
|
||
|
|
||
|
/**
|
||
|
*
|
||
|
* @param requestId The request id to poll the generation results
|
||
|
* @return
|
||
|
*/
|
||
|
std::vector<tle::Response> Poll(RequestId requestId);
|
||
|
|
||
|
/**
|
||
|
* Stop the underlying executor
|
||
|
*/
|
||
|
void Shutdown();
|
||
|
};
|
||
|
}
|
||
|
|
||
|
|
||
|
#endif //TGI_TRTLLM_BACKEND_H
|