feat(backend): initial rewrite of the backend for simplicity
This commit is contained in:
parent
a80c346f72
commit
f24e9fa2b9
|
@ -0,0 +1,38 @@
|
||||||
|
#include <ranges>
|
||||||
|
#include <utility>
|
||||||
|
#include "backend.hpp"
|
||||||
|
|
||||||
|
#include <spdlog/spdlog.h>
|
||||||
|
|
||||||
|
namespace huggingface::tgi::backends::trtllm {
|
||||||
|
|
||||||
|
size_t backend_t::num_tokens_ready() const noexcept {
|
||||||
|
return executor_.getNumResponsesReady();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::expected<request_id_t, backend_exception_t>
|
||||||
|
backend_t::submit(std::span<tle::TokenIdType> token_ids, generation_params_t generation_params, sampling_params_t sampling_params) noexcept {
|
||||||
|
SPDLOG_DEBUG(FMT_STRING("Submitting {:d} tokens to the executor for scheduling"), token_ids.size());
|
||||||
|
return executor_.enqueueRequest(tle::Request {
|
||||||
|
{token_ids.begin(), token_ids.end()}, // Making actual copy of the tokens
|
||||||
|
static_cast<tle::SizeType32>(generation_params.max_new_tokens),
|
||||||
|
true,
|
||||||
|
(tle::SamplingConfig) sampling_params,
|
||||||
|
tle::OutputConfig { /* returnLogProbs= */ true },
|
||||||
|
std::nullopt,
|
||||||
|
std::nullopt,
|
||||||
|
std::nullopt,
|
||||||
|
std::nullopt,
|
||||||
|
stop_words_
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<tle::Response> backend_t::pull_tokens() noexcept {
|
||||||
|
return executor_.awaitResponses();
|
||||||
|
}
|
||||||
|
|
||||||
|
void backend_t::cancel(request_id_t request_id) noexcept {
|
||||||
|
SPDLOG_INFO(FMT_STRING("Cancelling request: {:d}"), request_id);
|
||||||
|
executor_.cancelRequest(request_id);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,100 @@
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <exception>
|
||||||
|
#include <expected>
|
||||||
|
#include <list>
|
||||||
|
#include <span>
|
||||||
|
|
||||||
|
#include <tensorrt_llm/executor/executor.h>
|
||||||
|
|
||||||
|
namespace huggingface::tgi::backends::trtllm {
|
||||||
|
namespace tle = tensorrt_llm::executor;
|
||||||
|
|
||||||
|
using request_id_t = uint32_t;
|
||||||
|
using token_id_t = tle::TokenIdType;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represent the parameters used for generation
|
||||||
|
*/
|
||||||
|
struct generation_params_t {
|
||||||
|
uint32_t max_new_tokens;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represent the parameters used to sample tokens from the logit distribution
|
||||||
|
*/
|
||||||
|
struct sampling_params_t {
|
||||||
|
uint32_t top_k;
|
||||||
|
float_t top_p;
|
||||||
|
float_t repetition_penalty;
|
||||||
|
float_t frequency_penalty;
|
||||||
|
float_t length_penalty;
|
||||||
|
float_t temperature;
|
||||||
|
uint64_t seed;
|
||||||
|
|
||||||
|
explicit operator tle::SamplingConfig() const {
|
||||||
|
return tle::SamplingConfig {
|
||||||
|
1,
|
||||||
|
top_k,
|
||||||
|
top_p,
|
||||||
|
std::nullopt,
|
||||||
|
std::nullopt,
|
||||||
|
std::nullopt,
|
||||||
|
seed,
|
||||||
|
temperature,
|
||||||
|
std::nullopt,
|
||||||
|
std::nullopt,
|
||||||
|
repetition_penalty,
|
||||||
|
std::nullopt,
|
||||||
|
frequency_penalty,
|
||||||
|
length_penalty
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
class backend_exception_t: std::exception {};
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
class backend_t {
|
||||||
|
private:
|
||||||
|
tle::Executor executor_;
|
||||||
|
std::list<std::vector<int32_t>> stop_words_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* Submit a new request to the executor
|
||||||
|
* @param token_ids
|
||||||
|
* @param generation_params
|
||||||
|
* @param sampling_params
|
||||||
|
* @return Either newly submitted request's id or the error why it failed to submit
|
||||||
|
*/
|
||||||
|
[[nodiscard("Discarded executor request_id needs to be assigned")]]
|
||||||
|
std::expected<request_id_t, backend_exception_t>
|
||||||
|
submit(std::span<token_id_t> token_ids, generation_params_t generation_params, sampling_params_t sampling_params) noexcept;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Query the number of tokens available across all in-flight generations
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
[[nodiscard("Pulling out the number of tokens")]]
|
||||||
|
size_t num_tokens_ready() const noexcept;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Pull out newly generated tokens from the executor
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
[[nodiscard("")]]
|
||||||
|
std::vector<tle::Response> pull_tokens() noexcept;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cancel the specified request on the executor' set
|
||||||
|
* @param request_id Request's Identifier to remove from the in-flight executor
|
||||||
|
*/
|
||||||
|
void cancel(request_id_t) noexcept;
|
||||||
|
};
|
||||||
|
}
|
|
@ -0,0 +1,60 @@
|
||||||
|
|
||||||
|
#include <tensorrt_llm/common/tllmException.h>
|
||||||
|
|
||||||
|
namespace rust::behavior {
|
||||||
|
template<typename Try, typename Fail>
|
||||||
|
static void trycatch(Try &&func, Fail &&fail) noexcept try {
|
||||||
|
func();
|
||||||
|
} catch (tensorrt_llm::common::TllmException &e) {
|
||||||
|
fail(e.what());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#include <backend.hpp>
|
||||||
|
|
||||||
|
namespace huggingface::tgi::backends::trtllm {
|
||||||
|
|
||||||
|
class tensorrt_llm_backend_t {
|
||||||
|
private:
|
||||||
|
backend_t inner_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
tensorrt_llm_backend_t(std::filesystem::path &engine_folder): inner_(engine_folder) {}
|
||||||
|
|
||||||
|
size_t num_tokens_ready() const noexcept {
|
||||||
|
return inner_.num_tokens_ready();
|
||||||
|
}
|
||||||
|
|
||||||
|
request_id_t submit(
|
||||||
|
rust::Slice<const uint32_t> tokens,
|
||||||
|
uint32_t max_new_tokens,
|
||||||
|
uint32_t top_k,
|
||||||
|
float_t top_p,
|
||||||
|
float_t temperature,
|
||||||
|
float_t repetition_penalty,
|
||||||
|
float_t frequency_penalty,
|
||||||
|
uint64_t seed
|
||||||
|
) {
|
||||||
|
// Submit the request to the executor and get back a potential request_id used to track request status
|
||||||
|
const auto maybe_request_id = inner_.submit(
|
||||||
|
{tokens_.data(), tokens.size()},
|
||||||
|
{max_new_tokens},
|
||||||
|
{top_k, top_p, repetition_penalty, frequency_penalty, temperature, seed}
|
||||||
|
);
|
||||||
|
|
||||||
|
// If we do have a value, let's return the request_id
|
||||||
|
if(maybe_request_id.has_value()) [[likely]] {
|
||||||
|
return *maybe_request_id;
|
||||||
|
} else {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void cancel(request_id_t requestId) noexcept {
|
||||||
|
SPDLOG
|
||||||
|
inner_.cancel(requestId);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue