oops missing c++ backend definitions
parent 7784a21d48
commit a01cd030d4
@@ -58,6 +58,12 @@ namespace huggingface::tgi::backends {
         */
        [[nodiscard]] bool IsReady() const;

        /***
         * Query the executor for the number of tokens available for pulling
         * @return
         */
        [[nodiscard]] size_t NumResponsesReady() const;

        /***
         * Submit a new generation task to the executor
         * @param tokens
@@ -70,7 +76,6 @@ namespace huggingface::tgi::backends {
         */
        [[nodiscard]] RequestId Submit(
                const std::vector<TokenId> &tokens,
                int32_t maxNewTokens,
                int32_t topK,
                float_t topP,
                float_t temperature,
@@ -84,15 +89,6 @@ namespace huggingface::tgi::backends {
         */
        std::vector<tle::Response> Poll(RequestId requestId);

        /***
         * Unroll the token generation until end of stream is reached.
         * Every generated token is streamed back through the provided callback for further processing.
         * @param reqId The request id to unroll
         * @param cb The callback to stream tokens back
         * @return Global number of generated tokens for this request id
         */
        uint32_t Stream(RequestId reqId, std::function<TokenStreamingCallback> &cb);

        /***
         * Stop the underlying executor
         */
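Taken together, these declarations describe the intended call pattern: Submit hands back a request id that is later fed to Poll or Stream. A minimal usage sketch under that reading follows; GenerateAll and all sampling values are illustrative only, and TokenId, RequestId and TokenStreamingCallback are assumed to be namespace-level aliases exposed by backend.h.

#include <cstdint>
#include <functional>
#include <vector>

#include <spdlog/spdlog.h>

#include "backend.h"

// Hypothetical driver for the API declared above. The backend instance is taken by
// reference because its construction is not part of this diff.
uint32_t GenerateAll(huggingface::tgi::backends::TensorRtLlmBackend &backend,
                     const std::vector<huggingface::tgi::backends::TokenId> &prompt) {
    if (!backend.IsReady()) return 0;

    // Sampling values below are placeholders, not defaults taken from the backend.
    const auto requestId = backend.Submit(prompt, /* maxNewTokens */ 128, /* topK */ 50,
                                          /* topP */ 0.9f, /* temperature */ 0.7f, /* seed */ 2024);

    // Stream pushes every generated token through the callback until end of stream,
    // assuming TokenStreamingCallback is callable with a single token id.
    std::function<huggingface::tgi::backends::TokenStreamingCallback> callback =
            [](auto tokenId) { SPDLOG_INFO("generated token {}", tokenId); };

    return backend.Stream(requestId, callback);
}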
@@ -1,7 +1,7 @@
#include <fstream>

#include <nvml.h>
#include "fmt/format.h"
#include <fmt/ranges.h>
#include <spdlog/spdlog.h>

#include "backend.h"
@@ -72,20 +72,31 @@ bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const {
    return executor.canEnqueueRequests();
}

size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
    return executor.getNumResponsesReady();
}

[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
        const std::vector<tle::TokenIdType> &tokens,
        const int32_t maxNewTokens,
        const int32_t topK,
        const float_t topP,
        const float_t temperature,
        const uint64_t seed
) {
    SPDLOG_DEBUG(
#ifndef NDEBUG
    SPDLOG_INFO(
            FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
            tokens.size(),
            executor.getLatestIterationStats().back().numActiveRequests
    );
#else
    SPDLOG_INFO(
            FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"),
            fmt::join(tokens, ", "),
            executor.getLatestIterationStats().back().numActiveRequests
    );
#endif

    const auto sampling = tle::SamplingConfig{
            1,
@@ -100,40 +111,13 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
            std::nullopt,
    };
    const auto output = tle::OutputConfig{false, false, false};
    return executor.enqueueRequest(tle::Request{tokens, maxNewTokens, true, sampling, output});
}

uint32_t huggingface::tgi::backends::TensorRtLlmBackend::Stream(const tle::IdType reqId,
                                                                std::function<TokenStreamingCallback> &cb) {
    bool isFinal = false;
    size_t generatedTokens = 0;

    do {
        const auto responses = executor.awaitResponses(reqId);
        for (const auto &response: responses) {
            if (response.hasError()) {
                SPDLOG_WARN("Caught error during generation: {}", response.getErrorMsg());
                isFinal = true;
            } else {
                const auto generation = response.getResult();
                const auto token = generation.outputTokenIds[0][0];

                // Update the end of stream detection and overall number of generated tokens
                isFinal = generation.isFinal;
                ++generatedTokens;

                // Send the token back through the callback function for further processing
                cb(token);
            }
        }

    } while (!isFinal);

    // Return the number of generated tokens
    return generatedTokens;
    return executor.enqueueRequest(
            tle::Request{tokens, std::numeric_limits<tle::SizeType32>::max(), true, sampling, output});
}

[[nodiscard("Generated tokens result must be used")]]
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) {
    SPDLOG_INFO("Polling status for request {}", requestId);
    return executor.awaitResponses(requestId);
}
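This hunk appears to drop the blocking Stream loop from the implementation file while keeping Poll. Under that reading, the same end-of-stream accounting can be rebuilt on top of Poll alone. A minimal sketch, assuming the response accessors behave exactly as they do in the removed Stream body; DrainRequest itself is hypothetical and not part of this commit.

#include <cstddef>

#include <spdlog/spdlog.h>

#include "backend.h"

// Hypothetical consumer loop built on the Poll surface that remains in this file.
size_t DrainRequest(huggingface::tgi::backends::TensorRtLlmBackend &backend,
                    const huggingface::tgi::backends::RequestId requestId) {
    size_t generatedTokens = 0;
    bool isFinal = false;

    while (!isFinal) {
        // Poll wraps executor.awaitResponses and may hand back several responses at once.
        for (const auto &response: backend.Poll(requestId)) {
            if (response.hasError()) {
                SPDLOG_WARN("Caught error during generation: {}", response.getErrorMsg());
                return generatedTokens;
            }

            const auto generation = response.getResult();
            isFinal = generation.isFinal;
            ++generatedTokens;
            // generation.outputTokenIds[0][0] is the newly produced token and would be
            // forwarded to the caller here.
        }
    }

    return generatedTokens;
}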