//
// Created by Morgan Funtowicz on 9/28/2024.
//
#ifndef TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
|
|
|
|
#define TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
|
|
|
|
|
2024-10-22 16:09:10 -06:00
|
|
|
#include <cmath>
#include <cstdint>
#include <expected>
#include <filesystem>
#include <limits>
#include <memory>
#include <span>
#include <string>
#include <vector>

#include <llama.h>
|
|
|
|
|
2024-10-23 06:12:32 -06:00
|
|
|
// True when a llama.cpp status code indicates success (llama.cpp returns 0 on success).
// Fully parenthesized so the macro stays correct inside larger expressions
// (the previous expansion `x == 0` broke under operators with higher precedence,
// e.g. `2 - LLAMA_SUCCESS(rc)` parsed as `(2 - rc) == 0`).
#define LLAMA_SUCCESS(x) ((x) == 0)
|
2024-10-03 06:00:17 -06:00
|
|
|
|
2024-10-22 16:09:10 -06:00
|
|
|
namespace huggingface::tgi::backends::llama {
|
2024-10-22 07:22:56 -06:00
|
|
|
/// Error conditions reported to callers of CreateLlamaCppBackend through std::expected.
/// Plain enum (not enum class) so existing unqualified uses keep compiling.
enum TgiLlamaCppBackendError {
    // The model file at the provided filesystem path does not exist on disk.
    // Starts at 1 so the zero value never aliases a valid error code.
    MODEL_FILE_DOESNT_EXIST = 1
};
|
|
|
|
|
2024-10-03 06:00:17 -06:00
|
|
|
|
|
|
|
/**
 * llama.cpp-backed text-generation backend for TGI (declarations only;
 * definitions live in the corresponding .cpp).
 *
 * Holds raw llama.cpp handles set by the constructor. Ownership semantics are
 * not visible in this header — the user-declared destructor suggests it
 * releases them; confirm against the implementation.
 */
class TgiLlamaCppBackend {
    // Alias over llama.cpp's token identifier type (an integer id in the model vocabulary).
    using TokenId = llama_token;

private:
    llama_model* model;   // llama.cpp model handle (presumably freed in the destructor — confirm in .cpp)
    llama_context* ctx;   // llama.cpp inference context, presumably created for `model` — confirm in .cpp

    /**
     * Build a llama.cpp sampler from the given sampling arguments.
     *
     * NOTE(review): the return type is unique_ptr over a *pointer*
     * (std::unique_ptr<llama_sampler *>) — it owns the pointer slot, not the
     * sampler object itself, so the sampler would not be released. This is
     * likely meant to be std::unique_ptr<llama_sampler, deleter> using
     * llama_sampler_free — confirm against the .cpp definition before changing,
     * since altering the signature here breaks the out-of-line definition.
     *
     * @param topK Keep only the topK most probable tokens (top-k sampling).
     * @param topP Nucleus (top-p) sampling threshold.
     * @param frequencyPenalty Penalty scaled by how often a token has appeared — assumed llama.cpp semantics, confirm.
     * @param repetitionPenalty Penalty applied to already-emitted tokens — assumed llama.cpp semantics, confirm.
     * @param seed RNG seed for reproducible sampling.
     * @return Owning wrapper around the configured sampler (see NOTE above).
     */
    std::unique_ptr<llama_sampler *> GetSamplerFromArgs(
            uint32_t topK, float_t topP, float_t frequencyPenalty, float_t repetitionPenalty, uint64_t seed);

public:
    /**
     * Wrap an already-initialized llama.cpp model and context.
     * @param model Loaded llama.cpp model handle.
     * @param ctx   Inference context for that model.
     */
    TgiLlamaCppBackend(llama_model *model, llama_context *ctx);

    // User-declared destructor — presumably releases `model` and `ctx`; confirm in .cpp.
    ~TgiLlamaCppBackend();

    /**
     * Convert a text string into the model's token ids.
     * @param text Input text to tokenize.
     * @return Token ids for `text` in the model's vocabulary.
     */
    [[nodiscard]] std::vector<TgiLlamaCppBackend::TokenId> Tokenize(const std::string& text) const;

    /**
     * Generate new tokens from an already-tokenized prompt.
     * @param tokens Prompt token ids (non-owning view; must outlive the call).
     * @param topK Top-k sampling parameter (see GetSamplerFromArgs).
     * @param topP Nucleus sampling threshold; the 1.0f default effectively disables the top-p filter.
     * @param maxNewTokens Upper bound on generated tokens; defaults to uint32 max,
     *        i.e. effectively unbounded — generation then presumably stops on a
     *        model criterion such as end-of-sequence (confirm in the implementation).
     * @return The generated token ids.
     */
    [[nodiscard]] std::vector<TgiLlamaCppBackend::TokenId> Generate(
            std::span<const TokenId> tokens,
            uint32_t topK,
            float_t topP = 1.0f,
            uint32_t maxNewTokens = std::numeric_limits<uint32_t>::max()
    );
};
|
|
|
|
|
2024-10-22 07:22:56 -06:00
|
|
|
/**
 * Factory: create a TgiLlamaCppBackend from a model file on disk.
 *
 * @param root Filesystem path to the model file to load.
 * @return A ready backend on success, or MODEL_FILE_DOESNT_EXIST when `root`
 *         does not point to an existing file (the only error value declared in
 *         this header — confirm other failure modes against the .cpp).
 */
std::expected<std::unique_ptr<TgiLlamaCppBackend>, TgiLlamaCppBackendError>
CreateLlamaCppBackend(const std::filesystem::path& root);
|
2024-10-03 06:00:17 -06:00
|
|
|
}
|
|
|
|
|
|
|
|
#endif //TGI_LLAMA_CPP_BACKEND_BACKEND_HPP
|