implement the Stream method to send new tokens through a callback
parent 09292b06a0
commit 13eabfabcb
@@ -20,6 +20,8 @@ namespace tle = tensorrt_llm::executor;
 
 namespace huggingface::tgi::backends {
 
+    using TokenStreamingCallback = void(tle::TokenIdType);
+
     /**
      * Initialize all the components required by TRTLLM.
      * It is required to call this function before attempting to load any engine
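Note: the alias declares a plain function type, so std::function<TokenStreamingCallback> expands to std::function<void(tle::TokenIdType)> and accepts any callable with that signature. A minimal standalone sketch, assuming (as in TensorRT-LLM) that the token id type is a 32-bit integer:

    #include <cstdint>
    #include <functional>
    #include <iostream>

    // Stand-ins for the TensorRT-LLM types referenced above.
    using TokenIdType = std::int32_t;
    using TokenStreamingCallback = void(TokenIdType);

    int main() {
        // Any callable matching void(TokenIdType) converts implicitly.
        const std::function<TokenStreamingCallback> cb = [](TokenIdType token) {
            std::cout << "received token " << token << '\n';
        };
        cb(42);
    }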
@@ -28,9 +30,8 @@ namespace huggingface::tgi::backends {
 
     /**
      *
-     * @param config
-     * @param workerPath
-     * @param channel
+     * @param config TensorRT-LLM configuration object
+     * @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode
      * @return
      */
     tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
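Note: workerPath only matters in orchestrator mode, where the executor spawns executorWorker processes instead of running in the calling process. The sketch below shows how the path could be wired into a config; it is an assumption-heavy illustration (OrchestratorConfig, ParallelConfig and their constructor arguments are taken from the tensorrt_llm/executor headers and may differ between TensorRT-LLM releases), not the body of GetExecutorConfig from this commit:

    // Hypothetical helper, not part of this commit.
    tle::ExecutorConfig BuildOrchestratorConfig(const std::string &workerPath) {
        tle::ExecutorConfig execConfig(1);
        // Run in orchestrator mode, pointing at the executorWorker binary.
        const auto orchestrator = tle::OrchestratorConfig(true, workerPath);
        execConfig.setParallelConfig(tle::ParallelConfig(
                tle::CommunicationType::kMPI,
                tle::CommunicationMode::kORCHESTRATOR,
                std::nullopt,   // deviceIds: let the executor pick
                std::nullopt,   // participantIds
                orchestrator));
        return execConfig;
    }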
@@ -58,7 +59,7 @@ namespace huggingface::tgi::backends {
         }
 
         /***
-         *
+         * Submit a new generation task to the executor
          * @param tokens
          * @param maxNewTokens
          * @param topK
@@ -69,7 +70,7 @@ namespace huggingface::tgi::backends {
          * @param frequencyPenalty
          * @param seed
          * @param nTopTokens
-         * @return
+         * @return Request id related to this generation for reference
          */
         [[nodiscard]] tle::IdType Submit(
                 const std::vector<tle::TokenIdType> &tokens,
@@ -85,11 +86,13 @@ namespace huggingface::tgi::backends {
         );
 
         /***
-         *
-         * @param reqId
-         * @return
+         * Unroll the token generation until end of stream is reached.
+         * Every generated token is streamed back through the provided callback for further processing
+         * @param reqId The request id to unroll
+         * @param cb The callback to stream token back
+         * @return Global number of generated tokens for this request id
          */
-        std::vector<tle::Response> Poll(tle::IdType reqId);
+        size_t Stream(tle::IdType reqId, const std::function<TokenStreamingCallback>& cb);
     };
 }
 
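Note: with Poll replaced by Stream, the header now implies a synchronous, callback-driven call pattern. A hypothetical caller-side sketch (the backend instance and request id come from elsewhere; neither is part of this diff):

    // Collect every streamed token for a previously submitted request.
    size_t CollectTokens(huggingface::tgi::backends::TensorRtLlmBackend &backend,
                         const tle::IdType reqId,
                         std::vector<tle::TokenIdType> &out) {
        const size_t total = backend.Stream(reqId, [&out](tle::TokenIdType token) {
            out.push_back(token);  // invoked once per generated token
        });
        return total;  // matches out.size(): Stream counts exactly the tokens it forwarded
    }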
@@ -10,6 +10,7 @@ void huggingface::tgi::backends::InitializeBackend() {
     initTrtLlmPlugins();
 }
 
+[[nodiscard]]
 tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
     tle::ExecutorConfig execConfig(1);
 
@@ -64,6 +65,7 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
     SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string&>());
 }
 
+[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
 tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         const std::vector<tle::TokenIdType> &tokens,
         const int32_t maxNewTokens,
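Note: the bare [[nodiscard]] added to GetExecutorConfig above is C++17, but the [[nodiscard("reason")]] form used here is a C++20 feature; conforming compilers quote the reason string when a caller drops the return value. A minimal illustration:

    [[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
    int Submit() { return 42; }

    int main() {
        Submit();                     // warning, quoting the reason string
        const auto reqId = Submit();  // fine: the id is kept for the later Stream call
        (void) reqId;
        return 0;
    }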
@@ -76,8 +78,8 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
         std::optional<uint32_t> seed,
         std::optional<uint32_t> nTopTokens
 ) {
-    spdlog::debug(
-            FMT_STRING("Submitting inference over {:d} tokens to the executor {:d}"),
+    SPDLOG_DEBUG(
+            FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
             tokens.size(),
             executor.getLatestIterationStats().back().numActiveRequests
     );
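Note: spdlog::debug is always compiled in and filtered at runtime, while the SPDLOG_DEBUG macro is removed entirely at compile time when SPDLOG_ACTIVE_LEVEL is above debug, so its arguments are never evaluated; that matters here because executor.getLatestIterationStats() is computed only for the log line. Illustration:

    // Must be defined before the first spdlog include to take effect.
    #define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO
    #include <spdlog/spdlog.h>

    int main() {
        SPDLOG_DEBUG("expensive: {}", 40 + 2);  // compiled out, argument never evaluated
        spdlog::debug("compiled in, filtered at runtime by the logger level");
    }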
@@ -103,8 +105,31 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
     return executor.enqueueRequest(request);
 }
 
-std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType reqId) {
-    SPDLOG_DEBUG(FMT_STRING("Polling request {:d}"), reqId);
-    const auto responses = executor.awaitResponses(reqId);
-    return responses;
+size_t huggingface::tgi::backends::TensorRtLlmBackend::Stream(const tle::IdType reqId, const std::function<TokenStreamingCallback>& cb) {
+    bool isFinal = false;
+    size_t generatedTokens = 0;
+
+    do {
+        const auto responses = executor.awaitResponses(reqId);
+        for (const auto &response: responses) {
+            if (response.hasError()) {
+                SPDLOG_WARN("Caught error during generation: {}", response.getErrorMsg());
+                isFinal = true;
+            } else {
+                const auto generation = response.getResult();
+                const auto token = generation.outputTokenIds[0][0];
+
+                // Update the end of stream detection and overall number of generated tokens
+                isFinal = generation.isFinal;
+                ++generatedTokens;
+
+                // Send the token back through the callback function for further processing
+                cb(token);
+            }
+        }
+
+    } while (!isFinal);
+
+    // Return the number of generated tokens
+    return generatedTokens;
 }