//
// Created by Morgan Funtowicz on 9/28/2024.
//
#ifndef TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
#define TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
#include <atomic>
#include <cmath>
#include <cstdint>
#include <expected>
#include <filesystem>
#include <functional>
#include <limits>
#include <memory>
#include <optional>
#include <queue>
#include <span>
#include <stop_token>
#include <thread>
#include <vector>

#include <llama.h>
// True when a llama.cpp status code indicates success (0).
// Fully parenthesized so the expansion composes safely inside larger
// expressions (e.g. negation or comparison around the macro call).
#define LLAMA_SUCCESS(x) ((x) == 0)
namespace huggingface::tgi::backends::llamacpp {
static constexpr auto llama_context_deleter = [](llama_context *pContext) { llama_free(pContext); };
|
|
typedef std::unique_ptr<llama_context, decltype(llama_context_deleter)> llama_context_ptr;
|
|
|
|
static constexpr auto llama_sampler_deleter = [](llama_sampler *pSampler) { llama_sampler_free(pSampler); };
|
|
typedef std::unique_ptr<llama_sampler, decltype(llama_sampler_deleter)> llama_sampler_ptr;
|
|
|
|
typedef std::function<bool(llama_token, float_t, bool, size_t)> llama_decode_callback;
|
|
static constexpr auto llama_void_callback = [](llama_token, float_t, bool, size_t) -> bool { return false; };
|
|
|
|
/**
 * Error codes the backend reports through std::expected
 * (see worker_t::generate) instead of exceptions.
 */
enum backend_error_t : uint8_t {
    MODEL_FILE_DOESNT_EXIST = 1
};
/**
 * Sampling hyper-parameters controlling how the next token is drawn
 * from the model's logits.
 */
struct sampling_params_t {
    uint32_t top_k = std::numeric_limits<decltype(top_k)>::max();  // max() effectively disables top-k filtering
    float_t top_p = 1.0f;               // 1.0 disables nucleus (top-p) filtering
    float_t frequency_penalty = 0.0f;   // 0.0 disables the penalty
    float_t repetition_penalty = 0.0f;  // NOTE(review): llama.cpp conventionally uses 1.0 for "disabled" — confirm 0.0 is intended
    uint64_t seed = 2014;               // NOTE(review): possibly a typo for 2024 — confirm

    /**
     * Convert this sampling_params_t into the corresponding llama_sampler chain.
     * @param pModel Pointer to the model the sampler will be built against
     * @return Owning pointer to the configured llama_sampler
     */
    llama_sampler_ptr into_llama_sampler(const llama_model *pModel) const;
};
/**
 * Parameters bounding a single generation request.
 */
struct generation_params_t {
    uint32_t max_new_tokens = std::numeric_limits<uint32_t>::max();  // default: effectively unbounded
    bool ignore_eos_token = false;  // presumably lets generation continue past end-of-sequence — confirm in generate()
};
/**
 * Everything needed to run one generation: bounds, sampling settings,
 * and the prompt tokens.
 */
struct generation_context_t {
    generation_params_t generation_params;
    sampling_params_t sampling_params;
    std::span<const llama_token> input_tokens;  // non-owning view — the caller must keep the token buffer alive
};
/**
 * Holds a shared llama model plus a per-worker llama context and runs
 * token generation against them. Definitions live in the backend .cpp.
 */
class worker_t {
private:
    std::shared_ptr<llama_model> model_;  // shared ownership — the model may outlive / be shared across workers
    llama_context_ptr context_;           // per-worker context, freed via llama_free (see llama_context_ptr)

public:
    /**
     * Create a worker over the given model.
     * @param model Shared handle to the loaded llama model
     * @param params Context-creation parameters for this worker's llama_context
     */
    // NOTE(review): `const llama_context_params &&` binds only const rvalues and
    // cannot be moved from — by-value or `const &` would be clearer; changing it
    // here would require updating the out-of-line definition, so flagged only.
    worker_t(std::shared_ptr<llama_model>, const llama_context_params &&);

    /**
     * Run generation for one request.
     * @param generation_context Request bounds, sampling settings and prompt tokens
     * @param callback Optional per-token callback (see llama_decode_callback)
     * @return On success a size_t — presumably the number of generated tokens
     *         (confirm in the .cpp) — otherwise a backend_error_t
     */
    [[nodiscard]] std::expected<size_t, backend_error_t>
    generate(const generation_context_t &, const std::optional<llama_decode_callback> &) const;
};
} // namespace huggingface::tgi::backends::llamacpp

#endif //TGI_LLAMA_CPP_BACKEND_BACKEND_HPP