//
// Created by Morgan Funtowicz on 9/28/2024.
//

#ifndef TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
#define TGI_LLAMA_CPP_BACKEND_BACKEND_HPP

#include <atomic>
#include <cmath>
#include <expected>
#include <filesystem>
#include <functional>
#include <memory>
#include <optional>
#include <queue>
#include <span>
#include <stop_token>
#include <thread>
#include <vector>

#include <llama.h>

// Evaluates to true when a llama.cpp status code indicates success (0).
#define LLAMA_SUCCESS(x) ((x) == 0)
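
// A minimal usage sketch for LLAMA_SUCCESS, assuming a llama_context *ctx and a llama_batch
// batch prepared elsewhere; llama_decode returns 0 on success, which is the case the macro tests:
//
//   if (!LLAMA_SUCCESS(llama_decode(ctx, batch))) {
//       // handle the decoding failure
//   }
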
namespace huggingface::tgi::backends::llamacpp {

    // RAII wrappers releasing llama.cpp resources through the matching C API calls.
    static constexpr auto llama_context_deleter = [](llama_context *pContext) { llama_free(pContext); };
    typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_ptr;

    static constexpr auto llama_sampler_deleter = [](llama_sampler *pSampler) { llama_sampler_free(pSampler); };
    typedef std::unique_ptr<llama_sampler, decltype(llama_sampler_deleter)> llama_sampler_ptr;

    // Per-token callback used when streaming; llama_void_callback is a no-op default that always returns false.
    typedef std::function<bool(llama_token, float_t, bool, size_t)> llama_decode_callback;
    static constexpr auto llama_void_callback = [](llama_token, float_t, bool, size_t) -> bool { return false; };
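
    /*
     * A minimal sketch of a callback matching llama_decode_callback. The parameter meanings
     * (token id, score, end-of-stream flag, number of generated tokens) and the convention that
     * returning true stops generation early are assumptions for illustration, not guarantees
     * made by this header:
     *
     *   const llama_decode_callback stop_after_128 =
     *       [](llama_token, float_t, bool is_final, size_t n_generated) -> bool {
     *           return is_final || n_generated >= 128;  // ask the generation loop to stop
     *       };
     */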

    /**
     * Enumerates error conditions the backend can surface to callers.
     */
    enum backend_error_t : uint8_t {
        MODEL_FILE_DOESNT_EXIST = 1
    };

    /**
     * Sampling parameters applied when picking the next token (top-k/top-p filtering,
     * repetition and frequency penalties, RNG seed).
     */
    struct sampling_params_t {
        uint32_t top_k = std::numeric_limits<decltype(top_k)>::max();
        float_t top_p = 1.0f;
        float_t frequency_penalty = 0.0f;
        float_t repetition_penalty = 0.0f;
        uint64_t seed = 2014;

        /**
         * Convert this sampling_params_t to the respective llama_sampler structure
         * @param pModel Pointer to the model data
         * @return Owning pointer to the configured llama_sampler
         */
        llama_sampler_ptr into_llama_sampler(const llama_model *pModel) const;
    };
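
    /*
     * A usage sketch for sampling_params_t, assuming a loaded llama_model *model and a
     * llama_context *ctx exist elsewhere. llama_sampler_sample is part of the llama.cpp
     * sampling API; its exact signature may differ across llama.cpp versions:
     *
     *   sampling_params_t sampling;
     *   sampling.top_k = 40;
     *   sampling.top_p = 0.95f;
     *   sampling.seed = 42;
     *   const llama_sampler_ptr sampler = sampling.into_llama_sampler(model);
     *   const llama_token next_token = llama_sampler_sample(sampler.get(), ctx, -1);
     */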

    /**
     * Generation parameters controlling when decoding stops (token budget and EOS handling).
     */
    struct generation_params_t {
        uint32_t max_new_tokens = std::numeric_limits<uint32_t>::max();
        bool ignore_eos_token = false;
    };

    /**
     * Aggregates everything required to run a single generation: the generation and sampling
     * parameters plus the already tokenized prompt.
     */
    struct generation_context_t {
        generation_params_t generation_params;
        sampling_params_t sampling_params;
        std::span<const llama_token> input_tokens;
    };
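
    /*
     * A sketch showing how a generation_context_t might be assembled from an already tokenized
     * prompt; the prompt_tokens vector and the tokenize_prompt() helper are hypothetical and
     * provided by the caller:
     *
     *   const std::vector<llama_token> prompt_tokens = tokenize_prompt();  // hypothetical helper
     *   generation_context_t request{
     *       generation_params_t{.max_new_tokens = 256},
     *       sampling_params_t{},
     *       std::span<const llama_token>(prompt_tokens)
     *   };
     */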

    /**
     * Executes generation requests against a llama.cpp model, either one request at a time
     * through generate() or by draining a backlog of requests through loop().
     */
class worker_t {
    private:
        const std::shared_ptr<llama_model> mModel_;
        const llama_context_params mParams_;

    public:
        /**
         * Create a new worker sharing ownership of the provided model.
         * @param model Shared pointer to the loaded llama.cpp model
         * @param params Context parameters used when creating the underlying llama_context
         */
        worker_t(std::shared_ptr<llama_model> model, const llama_context_params &params);

        /**
         * Run a single generation request against the provided context.
         * @param context llama.cpp context to decode with
         * @param generation_context Prompt tokens plus generation and sampling parameters
         * @param callback Optional per-token callback invoked as tokens are produced
         * @return Number of tokens generated
         */
        size_t
        generate(llama_context *, const generation_context_t &, const std::optional<llama_decode_callback> &) const;

        /**
         * Continuously pop pending generation requests from the backlog and execute them until
         * the driver requests a stop.
         */
        void loop(std::stop_source &driver, std::queue<generation_context_t> &backlog) const;
    };
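
    /*
     * A sketch of how a worker_t could be driven from its own thread, assuming a
     * std::shared_ptr<llama_model> named model already loaded by the caller; synchronization of
     * the shared queue is deliberately elided, and the exact threading model is up to the
     * backend implementation:
     *
     *   const worker_t worker(model, llama_context_default_params());
     *   std::stop_source driver;
     *   std::queue<generation_context_t> backlog;
     *   std::thread worker_thread([&] { worker.loop(driver, backlog); });
     *   // ... push generation_context_t items into backlog, then:
     *   driver.request_stop();
     *   worker_thread.join();
     */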

    /**
     * Base class for the concrete backends below: owns the shared llama_model and defines the
     * generate() / stream() interface.
     */
    class backend_base_t {

    protected:
        std::shared_ptr<llama_model> mModel_;

    public:

        /**
         * Construct the backend, taking ownership of the provided model.
         * @param model Raw pointer to the loaded llama.cpp model
         */
        explicit backend_base_t(llama_model *model);

        /**
         * Virtual destructor so derived backends are destroyed correctly through a base pointer.
         */
        virtual ~backend_base_t();

        /**
         * Decode a full completion for the provided prompt, blocking until generation finishes.
         * @param tokens Prompt tokens to start generating from
         * @param generation_params Generation parameters (token budget, EOS handling)
         * @param sampling_params Sampling parameters used to pick each new token
         * @param callback Optional per-token callback invoked while tokens are produced
         * @return The generated tokens on success, a backend_error_t otherwise
         */
        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
        std::expected<std::vector<llama_token>, backend_error_t> generate(
                std::span<const llama_token> tokens,
                const generation_params_t &generation_params,
                const sampling_params_t &sampling_params,
                const std::optional<llama_decode_callback> &callback = std::nullopt
        );

        /**
         * Stream generated tokens for the provided prompt, forwarding each one to the callback.
         * @param tokens Prompt tokens to start generating from
         * @param generation_params Generation parameters (token budget, EOS handling)
         * @param sampling_params Sampling parameters used to pick each new token
         * @param callback Per-token callback invoked as tokens are produced
         * @return Number of generated tokens on success, a backend_error_t otherwise
         */
        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
        virtual std::expected<size_t, backend_error_t> stream(
                std::span<const llama_token> tokens,
                const generation_params_t &generation_params,
                const sampling_params_t &sampling_params,
                const llama_decode_callback &callback
        ) = 0;
    };
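
    /*
     * A sketch of calling the blocking generate() entry point and unpacking the std::expected
     * result, assuming `backend` is a concrete backend instance (e.g. single_worker_backend_t,
     * which re-exports generate with a using-declaration) and `prompt_tokens` is a
     * std::vector<llama_token> built by the caller:
     *
     *   const auto result = backend.generate(prompt_tokens, generation_params_t{}, sampling_params_t{});
     *   if (result.has_value()) {
     *       const std::vector<llama_token> &generated = result.value();
     *   } else {
     *       const backend_error_t err = result.error();
     *   }
     */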

    /**
     * Backend serving every request sequentially on a single worker and llama_context.
     */
    class single_worker_backend_t : backend_base_t {
    private:
        // Factory building the llama_context used by this backend: single-threaded, batch size 1,
        // flash attention enabled, causal attention, performance counters disabled.
        constexpr static auto llama_context_factory = [](llama_model *pModel) -> llama_context_ptr {
            auto llParams = llama_context_default_params();
            llParams.flash_attn = true;
            llParams.n_batch = 1;
            llParams.n_threads = 1;
            llParams.no_perf = true;
            llParams.attention_type = llama_attention_type::LLAMA_ATTENTION_TYPE_CAUSAL;

            return {llama_new_context_with_model(pModel, llParams), llama_context_deleter};
        };

        llama_context_ptr mContext_;
        worker_t mWorker_;

    public:
        explicit single_worker_backend_t(llama_model *pModel, const std::optional<llama_context_params> &);

        using backend_base_t::generate;

        std::expected<size_t, backend_error_t> stream(
                std::span<const llama_token> tokens,
                const generation_params_t &generation_params,
                const sampling_params_t &sampling_params,
                const llama_decode_callback &callback) override;
    };
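
    /*
     * A sketch of streaming with the single-worker backend; `prompt_tokens` is an already
     * tokenized prompt, the model path is illustrative, and llama_load_model_from_file comes
     * from the llama.cpp C API whose naming may differ between versions:
     *
     *   llama_model *model = llama_load_model_from_file("model.gguf", llama_model_default_params());
     *   single_worker_backend_t backend(model, std::nullopt);
     *   const auto result = backend.stream(
     *       prompt_tokens, generation_params_t{.max_new_tokens = 64}, sampling_params_t{},
     *       [](llama_token, float_t, bool, size_t) -> bool { return false; });
     *   // result.value() holds the number of streamed tokens on success
     */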

    /**
     * Backend variant intended to serve requests with multiple workers.
     */
    class multi_worker_backend_t : backend_base_t {
    private:
        llama_context_ptr mContext_;

    public:
        using backend_base_t::generate;

        std::expected<size_t, backend_error_t> stream(
                std::span<const llama_token> tokens,
                const generation_params_t &generation_params,
                const sampling_params_t &sampling_params,
                const llama_decode_callback &callback) override;
    };
}

#endif //TGI_LLAMA_CPP_BACKEND_BACKEND_HPP