2024-10-24 01:56:40 -06:00
|
|
|
//
|
|
|
|
// Created by mfuntowicz on 10/23/24.
|
|
|
|
//
|
|
|
|
|
|
|
|
#ifndef TGI_LLAMA_CPP_BACKEND_FFI_HPP
|
|
|
|
#define TGI_LLAMA_CPP_BACKEND_FFI_HPP
|
|
|
|
|
2024-10-24 08:42:50 -06:00
|
|
|
#include <exception>
|
|
|
|
#include <filesystem>
|
2024-11-09 14:10:33 -07:00
|
|
|
#include <memory>
|
2024-10-24 08:42:50 -06:00
|
|
|
#include <string_view>
|
2024-11-05 15:47:22 -07:00
|
|
|
#include <variant>
|
2024-10-24 08:42:50 -06:00
|
|
|
|
|
|
|
#include <spdlog/spdlog.h>
|
|
|
|
|
2024-10-31 10:51:57 -06:00
|
|
|
namespace huggingface::tgi::backends::llamacpp {
|
2024-11-09 14:10:33 -07:00
|
|
|
// Forward declaration only: the full class definition appears further down in this header,
// after the project includes that follow this namespace block.
// NOTE(review): presumably required so the cxx-generated "lib.rs.h" can name the type before
// it is defined -- confirm against the generated header.
class llama_cpp_worker_frontend_t;
|
2024-10-24 08:42:50 -06:00
|
|
|
}
|
|
|
|
|
2024-11-09 14:10:33 -07:00
|
|
|
#include "backend.hpp"
|
2024-10-24 08:42:50 -06:00
|
|
|
#include "backends/llamacpp/src/lib.rs.h"
|
2024-11-02 17:36:32 -06:00
|
|
|
#include "rust/cxx.h"
|
2024-10-24 01:56:40 -06:00
|
|
|
|
|
|
|
|
2024-10-31 10:51:57 -06:00
|
|
|
namespace huggingface::tgi::backends::llamacpp {
|
|
|
|
|
2024-11-09 14:10:33 -07:00
|
|
|
/**
 * Deleter releasing a llama_model through llama_free_model.
 *
 * Declared `inline constexpr` because this header defines the object at namespace scope:
 * a plain `auto` variable would be a mutable global with external linkage and would be
 * defined once per translation unit including this header.
 */
inline constexpr auto llama_model_deleter = [](llama_model *model) { llama_free_model(model); };
|
|
|
|
auto make_shared_llama_model = [](llama_model *model) {
|
|
|
|
return std::shared_ptr<llama_model>(model, llama_model_deleter);
|
2024-10-31 10:51:57 -06:00
|
|
|
};
|
|
|
|
|
2024-11-09 14:10:33 -07:00
|
|
|
/**
 * Exception raised by the FFI layer when the backend fails.
 *
 * The inheritance must be public: `class` defaults to private inheritance, which makes the
 * conversion to std::exception inaccessible and prevents `catch (const std::exception &)`
 * handlers from ever matching this type.
 */
class llama_cpp_backend_exception_t : public std::exception {};
|
2024-10-24 08:42:50 -06:00
|
|
|
|
2024-10-31 10:51:57 -06:00
|
|
|
/**
|
2024-11-09 14:10:33 -07:00
|
|
|
* Llama.cpp frontend over the worker interfacing with Rust FFI layer
|
2024-10-31 10:51:57 -06:00
|
|
|
*/
|
2024-11-09 14:10:33 -07:00
|
|
|
class llama_cpp_worker_frontend_t {
|
2024-10-24 08:42:50 -06:00
|
|
|
private:
|
2024-11-09 14:10:33 -07:00
|
|
|
std::shared_ptr<llama_model> model_;
|
|
|
|
worker_t worker_;
|
2024-10-24 01:56:40 -06:00
|
|
|
|
2024-10-24 08:42:50 -06:00
|
|
|
public:
|
2024-11-09 14:10:33 -07:00
|
|
|
explicit llama_cpp_worker_frontend_t(llama_model *model):
|
|
|
|
model_{ make_shared_llama_model(model) }, worker_(model_, {.no_perf = true}) {}
|
2024-10-24 08:42:50 -06:00
|
|
|
|
2024-11-02 17:36:32 -06:00
|
|
|
size_t stream(
|
2024-10-31 10:51:57 -06:00
|
|
|
rust::Slice<const uint32_t> input_tokens,
|
2024-11-02 17:36:32 -06:00
|
|
|
const generation_params_t generation_params,
|
2024-10-31 10:51:57 -06:00
|
|
|
const sampling_params_t &sampling_params,
|
2024-11-04 08:17:43 -07:00
|
|
|
InferContext *ctx,
|
2024-11-04 09:01:22 -07:00
|
|
|
rust::Fn<bool(InferContext *, uint32_t, float_t, bool, size_t)> callback
|
2024-10-31 10:51:57 -06:00
|
|
|
) {
|
2024-11-09 14:10:33 -07:00
|
|
|
auto context_forwarding_callback =
|
|
|
|
[=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens) -> bool {
|
|
|
|
return callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
|
2024-10-31 10:51:57 -06:00
|
|
|
};
|
|
|
|
|
2024-11-09 14:10:33 -07:00
|
|
|
// Ask the compiler to create view over Rust slice transmuting from uint32_t* to llama_token*
|
|
|
|
auto input_tokens_v =
|
|
|
|
std::span(reinterpret_cast<const llama_token *>(input_tokens.data()), input_tokens.size());
|
|
|
|
|
|
|
|
const auto generation_context = generation_context_t {generation_params, sampling_params, input_tokens_v};
|
|
|
|
if(const auto result = worker_.generate(generation_context, context_forwarding_callback); result.has_value()) [[likely]] {
|
2024-10-31 10:51:57 -06:00
|
|
|
return *result;
|
|
|
|
} else {
|
2024-11-09 14:10:33 -07:00
|
|
|
throw llama_cpp_backend_exception_t {};
|
2024-10-31 10:51:57 -06:00
|
|
|
}
|
2024-10-24 08:42:50 -06:00
|
|
|
}
|
2024-10-31 10:51:57 -06:00
|
|
|
};
|
|
|
|
|
2024-11-09 14:10:33 -07:00
|
|
|
/**
 * Load a model from disk with llama.cpp and wrap it into a worker frontend.
 *
 * Marked `inline` because the definition lives in a header.
 *
 * @param modelPath Path of the model file, forwarded to llama_load_model_from_file
 * @return Frontend owning the freshly loaded model
 * @throws llama_cpp_backend_exception_t if llama.cpp returns a null model (loading failed)
 */
inline std::unique_ptr<llama_cpp_worker_frontend_t> create_worker_frontend(rust::Str modelPath) {
    // Materialize an owning std::string so that a NUL-terminated c_str() can be handed to the C API
    const auto cxxPath = std::string(modelPath);

    // Start from llama.cpp defaults, only forcing memory-mapped loading
    auto params = llama_model_default_params();
    params.use_mmap = true;

    auto *model = llama_load_model_from_file(cxxPath.c_str(), params);
    if (model == nullptr) {
        // Fail here rather than handing a null model over to the worker
        throw llama_cpp_backend_exception_t {};
    }

    return std::make_unique<llama_cpp_worker_frontend_t>(model);
}
|
2024-10-24 01:56:40 -06:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#endif //TGI_LLAMA_CPP_BACKEND_FFI_HPP
|