// hf_text-generation-inference/backends/llamacpp/csrc/ffi.hpp
//
// Created by mfuntowicz on 10/23/24.
//
#ifndef TGI_LLAMA_CPP_BACKEND_FFI_HPP
#define TGI_LLAMA_CPP_BACKEND_FFI_HPP
#include <exception>
#include <filesystem>
#include <memory>
#include <ranges>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>
#include <spdlog/spdlog.h>
namespace huggingface::tgi::backends::llamacpp {
class llama_cpp_worker_frontend_t;
}
#include "backend.hpp"
#include "backends/llamacpp/src/lib.rs.h"
#include "rust/cxx.h"
namespace huggingface::tgi::backends::llamacpp {
/**
 * Deleter releasing a `llama_model` through `llama_free_model`.
 *
 * Declared `inline constexpr` because this object lives in a header: a plain `auto` variable at
 * namespace scope would be defined once per translation unit including the header and clash at link time.
 */
inline constexpr auto llama_model_deleter = [](llama_model *model) { llama_free_model(model); };
/**
 * Wrap a raw `llama_model` pointer into a `std::shared_ptr` which releases it with `llama_model_deleter`.
 *
 * Declared `inline constexpr` because this object lives in a header: a plain `auto` variable at
 * namespace scope would be defined once per translation unit including the header and clash at link time.
 *
 * @param model Raw model pointer whose ownership is transferred to the returned shared pointer
 * @return Shared pointer owning `model`
 */
inline constexpr auto make_shared_llama_model = [](llama_model *model) {
    return std::shared_ptr<llama_model>(model, llama_model_deleter);
};
/**
 * Exception thrown by the llama.cpp backend FFI layer.
 *
 * Inherits *publicly* from std::exception: with the default (private) inheritance of `class`,
 * a `catch (const std::exception &)` handler is unable to intercept this type.
 */
class llama_cpp_backend_exception_t : public std::exception {
public:
    /** @return Static, human-readable description of the failure */
    [[nodiscard]] const char *what() const noexcept override {
        return "llama.cpp backend error";
    }
};
/**
* Llama.cpp frontend over the worker interfacing with Rust FFI layer
*/
class llama_cpp_worker_frontend_t {
private:
std::shared_ptr<llama_model> model_;
worker_t worker_;
2024-10-24 01:56:40 -06:00
public:
explicit llama_cpp_worker_frontend_t(llama_model *model):
model_{ make_shared_llama_model(model) }, worker_(model_, {.no_perf = true}) {}
size_t stream(
rust::Slice<const uint32_t> input_tokens,
const generation_params_t generation_params,
const sampling_params_t &sampling_params,
InferContext *ctx,
rust::Fn<bool(InferContext *, uint32_t, float_t, bool, size_t)> callback
) {
2024-11-12 16:08:26 -07:00
// Wrapper around the provided Rust callback to inject the InferContext when returning from the C++ FFI boundaries
// It captures the context (ctx) using reference and will automatically call the Rust callback forwarding the InferContext
auto context_forwarding_callback =
[=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens) -> bool {
return callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
};
// Ask the compiler to create view over Rust slice transmuting from uint32_t* to llama_token*
static auto as_llama_token = [](const uint32_t x){ return static_cast<llama_token>(x); };
#ifdef __cpp_lib_ranges_to_container
auto input_tokens_v = input_tokens | std::views::transform(as_llama_token) | std::ranges::to<std::vector>();
#else
auto input_tokens_ = input_tokens | std::views::transform(as_llama_token);
auto input_tokens_v = std::vector<llama_token>(input_tokens_.begin(), input_tokens_.end());
#endif
// Defer the generation to the actual worker_t
const auto generation_context = generation_context_t {generation_params, sampling_params, input_tokens_v};
if(const auto result = worker_.generate(generation_context, context_forwarding_callback); result.has_value()) [[likely]] {
return *result;
} else {
throw llama_cpp_backend_exception_t {};
}
}
};
/**
 * Load the model located at `modelPath` and create the frontend driving generation on it.
 *
 * The first invocation also initializes llama.cpp's NUMA context with the numactl strategy.
 * Model weights are loaded with memory-mapping enabled.
 *
 * Marked `inline` because the definition lives in a header.
 *
 * @param modelPath Path to the model file, as a string slice borrowed from Rust
 * @return Owning pointer to the newly created frontend
 * @throws std::runtime_error if llama.cpp returns a null model for this path
 *         NOTE(review): assumes the cxx bridge declares this function as returning `Result<...>` so the
 *         exception is surfaced to Rust as an error - confirm in lib.rs.
 */
inline std::unique_ptr<llama_cpp_worker_frontend_t> create_worker_frontend(rust::Str modelPath) {
    // Initialize the numa context from numactl; the function-local static makes it run exactly once
    [[maybe_unused]] static const bool INITIALIZED_NUMA_CONTEXT_ONCE = []() {
        llama_numa_init(GGML_NUMA_STRATEGY_NUMACTL);
        return true;
    }();

    // Allocate model weights parameters
    auto params = llama_model_default_params();
    params.use_mmap = true;

    // rust::Str is not null-terminated: copy it to a std::string before handing it over to the C API
    const auto path = static_cast<std::string>(modelPath);

    // Allocate the model from the Rust provided, string path
    auto *model = llama_load_model_from_file(path.c_str(), params);
    if (model == nullptr) {
        // Fail here rather than handing a null model to the worker
        throw std::runtime_error("Failed to load llama.cpp model from: " + path);
    }

    return std::make_unique<llama_cpp_worker_frontend_t>(model);
}
}
#endif //TGI_LLAMA_CPP_BACKEND_FFI_HPP