feat(backend): entirely rewrite backend
This commit is contained in:
parent
611590440d
commit
b98c635781
|
@ -16,55 +16,89 @@
|
|||
|
||||
namespace huggingface::tgi::backends::llamacpp {
|
||||
|
||||
std::unique_ptr<llama_sampler> SamplingParams::IntoLlamaSampler(const llama_model *pModel) const {
|
||||
void llama_batch_fill_prompt(llama_batch &batch, std::span<const llama_token> input_tokens) {
|
||||
for (auto i = 0; i < input_tokens.size(); ++i) {
|
||||
batch.token[i] = input_tokens[i];
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i] = 0;
|
||||
batch.logits[i] = false;
|
||||
++batch.n_tokens;
|
||||
}
|
||||
|
||||
batch.logits[batch.n_tokens] = true;
|
||||
}
|
||||
|
||||
std::unique_ptr<llama_sampler> sampling_params_t::into_llama_sampler(const llama_model *model) const {
|
||||
auto *pSampler = llama_sampler_chain_init({.no_perf = false});
|
||||
|
||||
// Penalties
|
||||
llama_sampler_chain_add(pSampler, llama_sampler_init_penalties(
|
||||
llama_n_vocab(pModel),
|
||||
llama_token_eos(pModel),
|
||||
llama_token_nl(pModel),
|
||||
llama_n_vocab(model),
|
||||
llama_token_eos(model),
|
||||
llama_token_nl(model),
|
||||
0.0f,
|
||||
repetitionPenalty,
|
||||
frequencyPenalty,
|
||||
repetition_penalty,
|
||||
frequency_penalty,
|
||||
0.0f,
|
||||
false,
|
||||
false
|
||||
));
|
||||
llama_sampler_chain_add(pSampler, llama_sampler_init_top_k(static_cast<int32_t>(topK)));
|
||||
llama_sampler_chain_add(pSampler, llama_sampler_init_top_k(static_cast<int32_t>(top_k)));
|
||||
|
||||
if (0 < topP && topP < 1) {
|
||||
llama_sampler_chain_add(pSampler, llama_sampler_init_top_p(topP, 1));
|
||||
if (0 < top_p && top_p < 1) {
|
||||
llama_sampler_chain_add(pSampler, llama_sampler_init_top_p(top_p, 1));
|
||||
}
|
||||
|
||||
llama_sampler_chain_add(pSampler, llama_sampler_init_dist(seed));
|
||||
return std::unique_ptr<llama_sampler>(pSampler);
|
||||
}
|
||||
|
||||
Worker::Worker(std::shared_ptr<llama_model> pModel, const llama_context_params ¶ms)
|
||||
: mModel_(pModel), mParams_(params) {
|
||||
worker_t::worker_t(std::shared_ptr<llama_model> model, const llama_context_params ¶ms)
|
||||
: mModel_(model), mParams_(params) {
|
||||
|
||||
#ifdef TGI_LLAMACPP_BACKEND_DEBUG
|
||||
char modelName[256];
|
||||
llama_model_meta_val_str(pModel.get(), "general.name", modelName, sizeof(modelName));
|
||||
llama_model_meta_val_str(model.get(), "general.name", modelName, sizeof(modelName));
|
||||
SPDLOG_DEBUG(FMT_STRING("Created llama.cpp backend for model: '{}'"), std::string_view(modelName));
|
||||
#endif
|
||||
}
|
||||
|
||||
void Worker::Loop(std::atomic_flag &running, std::atomic_uint8_t &waiting, std::queue<SamplingParams> &backlog) {
|
||||
void worker_t::loop(std::stop_source &driver, std::queue<generation_context_t> &backlog) const {
|
||||
auto *context = llama_new_context_with_model(mModel_.get(), mParams_);
|
||||
|
||||
while (running.test(std::memory_order_acquire)) {
|
||||
if (waiting.load(std::memory_order_acquire) > 0) {
|
||||
--waiting;
|
||||
while (!driver.stop_requested()) {
|
||||
const auto generation_context = backlog.front();
|
||||
|
||||
auto request = backlog.front();
|
||||
auto sampler = request.IntoLlamaSampler(mModel_.get());
|
||||
generate(context, generation_context, std::nullopt);
|
||||
backlog.pop();
|
||||
|
||||
SPDLOG_DEBUG("Processed request ({:d} remaining)", backlog.size());
|
||||
}
|
||||
|
||||
llama_free(context);
|
||||
}
|
||||
|
||||
size_t worker_t::generate(
|
||||
llama_context *context,
|
||||
const generation_context_t &generation_context,
|
||||
const std::optional<llama_decode_callback> &callback) const {
|
||||
// Store information about context and generation size
|
||||
auto prompt_length = std::ssize(generation_context.input_tokens);
|
||||
auto max_new_tokens = generation_context.generation_params.max_new_tokens;
|
||||
|
||||
// Convert sampling params to what llama.cpp is looking for
|
||||
auto sampler = generation_context.sampling_params.into_llama_sampler(mModel_.get());
|
||||
|
||||
// Setup the prompt
|
||||
auto copy = std::vector(generation_context.input_tokens.begin(), generation_context.input_tokens.end());
|
||||
auto batch = llama_batch_get_one(copy.data(), copy.size());
|
||||
|
||||
// Retrieve decoding context
|
||||
auto batch = llama_batch_get_one(tokens.data(), tokens.size());
|
||||
// Decode
|
||||
for (auto [generating, nDecoded] = std::pair{true, 0uz}; generating && nDecoded < 1; ++nDecoded) {
|
||||
auto n_decoded_tokens = 0;
|
||||
for (bool generating = true; generating && n_decoded_tokens < max_new_tokens; ++n_decoded_tokens) {
|
||||
const auto callback_ = callback.value_or(llama_void_callback);
|
||||
|
||||
#ifdef TGI_LLAMACPP_BACKEND_DEBUG
|
||||
const auto start = std::chrono::steady_clock::now();
|
||||
const auto status = llama_decode(context, batch);
|
||||
|
@ -74,27 +108,64 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||
#else
|
||||
const auto status = llama_decode(ctx, batch);
|
||||
#endif
|
||||
batch.n_tokens = 0;
|
||||
if (LLAMA_SUCCESS(status)) {
|
||||
// Sample the new token
|
||||
auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
|
||||
generated.emplace_back(new_token_id);
|
||||
generating = !llama_token_is_eog(mModel_.get(), new_token_id);
|
||||
auto is_eos = llama_token_is_eog(mModel_.get(), new_token_id);
|
||||
|
||||
generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
|
||||
generating = !is_eos;
|
||||
|
||||
// Bubble up the generated token if a callback is provided
|
||||
std::invoke(std::forward<const llama_decode_callback>(callback_), new_token_id, is_eos);
|
||||
|
||||
// Next iteration
|
||||
batch = llama_batch_get_one(&new_token_id, 1);
|
||||
}
|
||||
}
|
||||
|
||||
backlog.pop();
|
||||
|
||||
}
|
||||
return n_decoded_tokens;
|
||||
}
|
||||
|
||||
llama_free(context);
|
||||
|
||||
backend_base_t::backend_base_t(llama_model *model) : mModel_(model, llama_free_model) { llama_backend_init(); }
|
||||
|
||||
backend_base_t::~backend_base_t() { llama_backend_free(); }
|
||||
|
||||
std::expected<std::vector<llama_token>, backend_error_t> backend_base_t::generate(
|
||||
std::span<const llama_token> tokens,
|
||||
const generation_params_t &generation_params,
|
||||
const sampling_params_t &sampling_params,
|
||||
const std::optional<llama_decode_callback> &callback
|
||||
) {
|
||||
// TODO: Should we provide a way to change this value?
|
||||
auto generated = std::vector<llama_token>(2 << 8);
|
||||
|
||||
auto nTokensGenerated = generate(tokens, generated, generation_params, sampling_params, callback);
|
||||
if (nTokensGenerated.has_value())
|
||||
generated.resize(*nTokensGenerated);
|
||||
return generated;
|
||||
}
|
||||
|
||||
huggingface::tgi::backends::llamacpp::BackendBase::BackendBase(llama_model *model)
|
||||
: mModel_(model, llama_free_model) { llama_backend_init(); }
|
||||
|
||||
BackendBase::~BackendBase() { llama_backend_free(); }
|
||||
/** Single worker_t Backend impl **/
|
||||
|
||||
single_worker_backend_t::single_worker_backend_t(llama_model *model,
|
||||
const std::optional<llama_context_params> ¶ms)
|
||||
: backend_base_t(model),
|
||||
mContext_(llama_context_factory(model)),
|
||||
mWorker_(mModel_, params.value_or(llama_context_default_params())) {
|
||||
llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_NUMACTL);
|
||||
}
|
||||
|
||||
std::expected<std::size_t, backend_error_t>
|
||||
single_worker_backend_t::generate(
|
||||
std::span<const llama_token> tokens,
|
||||
std::span<llama_token> out,
|
||||
const generation_params_t &generation_params,
|
||||
const sampling_params_t &sampling_params,
|
||||
const std::optional<llama_decode_callback> &callback
|
||||
) {
|
||||
return mWorker_.generate(mContext_.get(), {generation_params, sampling_params, tokens, out}, callback);
|
||||
}
|
||||
}
|
|
@ -8,25 +8,42 @@
|
|||
#include <cmath>
|
||||
#include <expected>
|
||||
#include <filesystem>
|
||||
#include <functional>
|
||||
#include <queue>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <span>
|
||||
#include <stop_token>
|
||||
#include <vector>
|
||||
|
||||
#include <llama.h>
|
||||
#include <thread>
|
||||
|
||||
#define LLAMA_SUCCESS(x) x == 0
|
||||
|
||||
namespace huggingface::tgi::backends::llamacpp {
|
||||
enum BackendError : uint8_t {
|
||||
|
||||
static constexpr auto llama_context_deleter = [](llama_context *pContext) { llama_free(pContext); };
|
||||
typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_smart_ptr;
|
||||
|
||||
typedef std::function<void(llama_token, bool)> llama_decode_callback;
|
||||
static constexpr auto llama_void_callback = [](llama_token token_id, bool is_eos) {};
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
enum backend_error_t : uint8_t {
|
||||
MODEL_FILE_DOESNT_EXIST = 1
|
||||
};
|
||||
|
||||
struct SamplingParams {
|
||||
uint32_t topK = std::numeric_limits<decltype(topK)>::max();
|
||||
float_t topP = 1.0f;
|
||||
float_t frequencyPenalty = 0.0f;
|
||||
float_t repetitionPenalty = 0.0f;
|
||||
/**
|
||||
*
|
||||
*/
|
||||
struct sampling_params_t {
|
||||
uint32_t top_k = std::numeric_limits<decltype(top_k)>::max();
|
||||
float_t top_p = 1.0f;
|
||||
float_t frequency_penalty = 0.0f;
|
||||
float_t repetition_penalty = 0.0f;
|
||||
uint64_t seed = 2014;
|
||||
|
||||
/**
|
||||
|
@ -34,38 +51,72 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||
* @param Pointer to the model data
|
||||
* @return
|
||||
*/
|
||||
std::unique_ptr<llama_sampler> IntoLlamaSampler(const llama_model *) const;
|
||||
std::unique_ptr<llama_sampler> into_llama_sampler(const llama_model *pModel) const;
|
||||
};
|
||||
|
||||
class Worker {
|
||||
/**
|
||||
*
|
||||
*/
|
||||
struct generation_params_t {
|
||||
uint32_t max_new_tokens = std::numeric_limits<uint32_t>::max();
|
||||
};
|
||||
|
||||
struct generation_context_t {
|
||||
generation_params_t generation_params;
|
||||
sampling_params_t sampling_params;
|
||||
std::span<const llama_token> input_tokens;
|
||||
std::span<llama_token> generated_tokens;
|
||||
};
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
class worker_t {
|
||||
private:
|
||||
const std::shared_ptr<llama_model> mModel_;
|
||||
const llama_context_params mParams_;
|
||||
|
||||
public:
|
||||
/**
|
||||
*
|
||||
* @param model
|
||||
* @param params
|
||||
*/
|
||||
worker_t(std::shared_ptr<llama_model> model, const llama_context_params ¶ms);
|
||||
|
||||
/**
|
||||
*
|
||||
* @param context
|
||||
* @param generation_context
|
||||
* @param callback
|
||||
*/
|
||||
size_t
|
||||
generate(llama_context *, const generation_context_t &, const std::optional<llama_decode_callback> &) const;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
void loop(std::stop_source &driver, std::queue<generation_context_t> &backlog) const;
|
||||
};
|
||||
|
||||
|
||||
class backend_base_t {
|
||||
|
||||
protected:
|
||||
constexpr static auto llama_context_deleter = [](llama_context *pContext) { llama_free(pContext); };
|
||||
|
||||
public:
|
||||
using model_ptr_type = std::shared_ptr<llama_model>;
|
||||
using context_params_type = llama_context_params;
|
||||
using token_id_type = llama_token;
|
||||
|
||||
private:
|
||||
const model_ptr_type mModel_;
|
||||
context_params_type mParams_;
|
||||
|
||||
public:
|
||||
Worker(std::shared_ptr<llama_model> pModel, const llama_context_params ¶ms);
|
||||
|
||||
void Loop(std::atomic_flag &, std::atomic_uint8_t &, std::queue<SamplingParams> &) const;
|
||||
};
|
||||
|
||||
|
||||
class BackendBase {
|
||||
|
||||
private:
|
||||
std::shared_ptr<llama_model> mModel_;
|
||||
|
||||
public:
|
||||
explicit BackendBase(llama_model *model);
|
||||
|
||||
~BackendBase();
|
||||
/**
|
||||
*
|
||||
* @param model
|
||||
*/
|
||||
explicit backend_base_t(llama_model *model);
|
||||
|
||||
/**
|
||||
* Destructor
|
||||
*/
|
||||
~backend_base_t();
|
||||
|
||||
/**
|
||||
*
|
||||
|
@ -76,12 +127,13 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||
* @return
|
||||
*/
|
||||
[[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
|
||||
std::expected<std::vector<llama_token>, BackendError> Generate(
|
||||
std::span<const llama_token> tokens,
|
||||
std::span<llama_token> out,
|
||||
const SamplingParams ¶ms,
|
||||
uint32_t maxNewTokens = std::numeric_limits<uint32_t>::max() - 1
|
||||
);
|
||||
virtual std::expected<size_t, backend_error_t> generate(
|
||||
std::span<const llama_token> input_tokens,
|
||||
std::span<llama_token> generated_tokens,
|
||||
const generation_params_t &generation_params,
|
||||
const sampling_params_t &sampling_params,
|
||||
const std::optional<llama_decode_callback> &callback
|
||||
) = 0;
|
||||
|
||||
/**
|
||||
*
|
||||
|
@ -91,12 +143,46 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||
* @return
|
||||
*/
|
||||
[[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
|
||||
std::expected<std::vector<llama_token>, BackendError> Generate(
|
||||
std::expected<std::vector<llama_token>, backend_error_t> generate(
|
||||
std::span<const llama_token> tokens,
|
||||
const SamplingParams ¶ms,
|
||||
uint32_t maxNewTokens = std::numeric_limits<uint32_t>::max() - 1
|
||||
const generation_params_t &generation_params,
|
||||
const sampling_params_t &sampling_params,
|
||||
const std::optional<llama_decode_callback> &callback = std::nullopt
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
class single_worker_backend_t : backend_base_t {
|
||||
private:
|
||||
constexpr const static auto llama_context_factory = [](llama_model *pModel) -> llama_context_smart_ptr {
|
||||
auto llParams = llama_context_default_params();
|
||||
llParams.flash_attn = true;
|
||||
llParams.n_batch = 1;
|
||||
llParams.no_perf = true;
|
||||
llParams.attention_type = llama_attention_type::LLAMA_ATTENTION_TYPE_CAUSAL;
|
||||
|
||||
return {llama_new_context_with_model(pModel, llParams), llama_context_deleter};
|
||||
};
|
||||
|
||||
llama_context_smart_ptr mContext_;
|
||||
worker_t mWorker_;
|
||||
|
||||
public:
|
||||
explicit single_worker_backend_t(llama_model *pModel, const std::optional<llama_context_params> &);
|
||||
|
||||
using backend_base_t::generate;
|
||||
|
||||
std::expected<size_t, backend_error_t>
|
||||
generate(
|
||||
std::span<const llama_token> tokens,
|
||||
std::span<llama_token> out,
|
||||
const generation_params_t &generation_params,
|
||||
const sampling_params_t &sampling_params,
|
||||
const std::optional<llama_decode_callback> &callback
|
||||
) override;
|
||||
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
#endif //TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
|
||||
|
|
Loading…
Reference in New Issue