feat(backend): simplify overall cpp structure
parent 4f5397c414
commit 86d30aea43
@@ -49,43 +49,28 @@ namespace huggingface::tgi::backends::llamacpp {
         }

         llama_sampler_chain_add(pSampler, llama_sampler_init_dist(seed));
-        return llama_sampler_ptr(pSampler, llama_sampler_deleter);
+        return {pSampler, llama_sampler_deleter};
     }

     worker_t::worker_t(std::shared_ptr<llama_model> model, const llama_context_params &params)
-            : mModel_(model), mParams_(params) {
+            : model_(model), context_(llama_new_context_with_model(model_.get(), params)) {

 #ifdef TGI_LLAMACPP_BACKEND_DEBUG
         char modelName[256];
         llama_model_meta_val_str(model.get(), "general.name", modelName, sizeof(modelName));
         SPDLOG_DEBUG(FMT_STRING("Created llama.cpp backend for model: '{}'"), std::string_view(modelName));
 #endif
     }

-    void worker_t::loop(std::stop_source &driver, std::queue<generation_context_t> &backlog) const {
-        auto *context = llama_new_context_with_model(mModel_.get(), mParams_);
-
-        while (!driver.stop_requested()) {
-            const auto generation_context = backlog.front();
-
-            generate(context, generation_context, std::nullopt);
-            backlog.pop();
-
-            SPDLOG_DEBUG("Processed request ({:d} remaining)", backlog.size());
-        }
-
-        llama_free(context);
-    }
-
-    size_t worker_t::generate(
-            llama_context *context,
-            const generation_context_t &generation_context,
-            const std::optional<llama_decode_callback> &callback) const {
+    std::expected<size_t, backend_error_t>
+    worker_t::generate(const generation_context_t &generation_context,
+                       const std::optional<llama_decode_callback> &callback) const {
         // Store information about context and generation size
+        const auto callback_ = callback.value_or(llama_void_callback);
         auto max_new_tokens = generation_context.generation_params.max_new_tokens;

         // Convert sampling params to what llama.cpp is looking for
-        auto sampler = generation_context.sampling_params.into_llama_sampler(mModel_.get());
+        auto sampler = generation_context.sampling_params.into_llama_sampler(model_.get());

         // Set up the prompt
         auto copy = std::vector(generation_context.input_tokens.begin(), generation_context.input_tokens.end());
@@ -94,11 +79,10 @@ namespace huggingface::tgi::backends::llamacpp {
         // Decode
         auto n_decoded_tokens = 0;
         for (bool generating = true; generating; ++n_decoded_tokens) {
-            const auto callback_ = callback.value_or(llama_void_callback);

 #ifdef TGI_LLAMACPP_BACKEND_DEBUG
             const auto start = std::chrono::steady_clock::now();
-            const auto status = llama_decode(context, batch);
+            const auto status = llama_decode(context_.get(), batch);
             const auto end = std::chrono::steady_clock::now();
             const auto latency = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
             SPDLOG_DEBUG(FMT_STRING("Successfully decoded {:d} token(s) in {}"), batch.n_tokens, latency);
@@ -108,8 +92,8 @@ namespace huggingface::tgi::backends::llamacpp {
             batch.n_tokens = 0;
             if (LLAMA_SUCCESS(status)) [[likely]] {
                 // Sample the new token
-                auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
-                auto is_eog = llama_token_is_eog(mModel_.get(), new_token_id);
+                auto new_token_id = llama_sampler_sample(sampler.get(), context_.get(), -1);
+                auto is_eog = llama_token_is_eog(model_.get(), new_token_id);
                 auto new_token_logits = 0.0f; // TODO: return logit

                 // Handle termination cases
@@ -119,11 +103,8 @@ namespace huggingface::tgi::backends::llamacpp {
                 generating = !(has_reach_max_tokens | has_reach_eog);

                 // Bubble up the generated token if a callback is provided
-                const auto should_stop = std::invoke(std::forward<const llama_decode_callback>(callback_),
-                                                     new_token_id,
-                                                     new_token_logits,
-                                                     !generating,
-                                                     n_decoded_tokens + 1);
+                const auto should_stop =
+                        std::invoke(callback_, new_token_id, new_token_logits, !generating, n_decoded_tokens + 1);
                 generating ^= should_stop;

                 batch = llama_batch_get_one(&new_token_id, 1);
@@ -132,62 +113,4 @@ namespace huggingface::tgi::backends::llamacpp {

         return n_decoded_tokens;
     }
-
-
-    backend_base_t::backend_base_t(llama_model *model) : mModel_(model, llama_free_model) { llama_backend_init(); }
-
-    backend_base_t::~backend_base_t() { llama_backend_free(); }
-
-    std::expected<std::vector<llama_token>, backend_error_t> backend_base_t::generate(
-            std::span<const llama_token> tokens,
-            const generation_params_t &generation_params,
-            const sampling_params_t &sampling_params,
-            const std::optional<llama_decode_callback> &callback
-    ) {
-        // TODO: Should we provide a way to change this value?
-        auto generated = std::vector<llama_token>(2 << 8);
-        auto inner_callback = [&](uint32_t new_token_id, float_t new_token_logit, bool is_eos,
-                                  size_t num_generated_tokens) -> bool {
-            generated.emplace_back(new_token_id);
-
-            if (callback.has_value())
-                return (*callback)(new_token_id, new_token_logit, is_eos, num_generated_tokens);
-            return true;
-        };
-
-        auto nTokensGenerated = stream(tokens, generation_params, sampling_params, inner_callback);
-        return generated;
-    }
-
-
-    /** Single worker_t Backend impl **/
-
-    single_worker_backend_t::single_worker_backend_t(llama_model *model,
-                                                     const std::optional<llama_context_params> &params)
-            : backend_base_t(model),
-              mContext_(llama_context_factory(model)),
-              mWorker_(mModel_, params.value_or(llama_context_default_params())) {
-        llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_NUMACTL);
-    };
-
-    std::expected<size_t, backend_error_t>
-    single_worker_backend_t::stream(
-            std::span<const llama_token> tokens,
-            const generation_params_t &generation_params,
-            const sampling_params_t &sampling_params,
-            const llama_decode_callback &callback
-    ) {
-        return mWorker_.generate(mContext_.get(), {generation_params, sampling_params, tokens}, callback);
-    }
-
-    std::expected<size_t, backend_error_t>
-    multi_worker_backend_t::stream(
-            std::span<const llama_token> tokens,
-            const generation_params_t &generation_params,
-            const sampling_params_t &sampling_params,
-            const llama_decode_callback &callback
-    ) {
-        SPDLOG_WARN("Not implemented for multi_worker_t");
-        return 0;
-    }
 }
@@ -76,8 +76,8 @@ namespace huggingface::tgi::backends::llamacpp {
      */
     class worker_t {
     private:
-        const std::shared_ptr<llama_model> mModel_;
-        const llama_context_params mParams_;
+        std::shared_ptr<llama_model> model_;
+        llama_context_ptr context_;

     public:
         /**
@@ -85,7 +85,7 @@ namespace huggingface::tgi::backends::llamacpp {
          * @param model
          * @param params
          */
-        worker_t(std::shared_ptr<llama_model> model, const llama_context_params &params);
+        worker_t(std::shared_ptr<llama_model>, const llama_context_params &);

         /**
          *
@@ -93,108 +93,8 @@ namespace huggingface::tgi::backends::llamacpp {
          * @param generation_context
          * @param callback
          */
-        size_t
-        generate(llama_context *, const generation_context_t &, const std::optional<llama_decode_callback> &) const;
-
-        /**
-         *
-         */
-        void loop(std::stop_source &driver, std::queue<generation_context_t> &backlog) const;
-    };
-
-
-    class backend_base_t {
-
-    protected:
-        std::shared_ptr<llama_model> mModel_;
-
-    public:
-
-        /**
-         *
-         * @param model
-         */
-        explicit backend_base_t(llama_model *model);
-
-        /**
-         * Destructor
-         */
-        ~backend_base_t();
-
-        /**
-         *
-         * @param tokens
-         * @param generation_params
-         * @param sampling_params
-         * @param callback
-         * @return
-         */
-        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
-        std::expected<std::vector<llama_token>, backend_error_t> generate(
-                std::span<const llama_token> tokens,
-                const generation_params_t &generation_params,
-                const sampling_params_t &sampling_params,
-                const std::optional<llama_decode_callback> &callback = std::nullopt
-        );
-
-        /**
-         *
-         * @param tokens
-         * @param generation_params
-         * @param sampling_params
-         * @params callback
-         * @return
-         */
-        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
-        virtual std::expected<size_t, backend_error_t> stream(
-                std::span<const llama_token> tokens,
-                const generation_params_t &generation_params,
-                const sampling_params_t &sampling_params,
-                const llama_decode_callback &callback
-        ) = 0;
-    };
-
-
-    class single_worker_backend_t : backend_base_t {
-    private:
-        constexpr static auto llama_context_factory = [](llama_model *pModel) -> llama_context_ptr {
-            auto llParams = llama_context_default_params();
-            llParams.flash_attn = true;
-            llParams.n_batch = 1;
-            llParams.n_threads = 1;
-            llParams.no_perf = true;
-            llParams.attention_type = llama_attention_type::LLAMA_ATTENTION_TYPE_CAUSAL;
-
-            return {llama_new_context_with_model(pModel, llParams), llama_context_deleter};
-        };
-
-        llama_context_ptr mContext_;
-        worker_t mWorker_;
-
-    public:
-        explicit single_worker_backend_t(llama_model *pModel, const std::optional<llama_context_params> &);
-
-        using backend_base_t::generate;
-
-        std::expected<size_t, backend_error_t> stream(
-                std::span<const llama_token> tokens,
-                const generation_params_t &generation_params,
-                const sampling_params_t &sampling_params,
-                const llama_decode_callback &callback) override;
-    };
-
-    class multi_worker_backend_t : backend_base_t {
-    private:
-        llama_context_ptr mContext_;
-
-    public:
-        using backend_base_t::generate;
-
-        std::expected<size_t, backend_error_t> stream(
-                std::span<const llama_token> tokens,
-                const generation_params_t &generation_params,
-                const sampling_params_t &sampling_params,
-                const llama_decode_callback &callback) override;
+        [[nodiscard]] std::expected<size_t, backend_error_t>
+        generate(const generation_context_t &, const std::optional<llama_decode_callback> &) const;
     };
 }
@@ -7,58 +7,41 @@

 #include <exception>
 #include <filesystem>
 #include <memory>
 #include <string_view>
 #include <variant>

 #include <spdlog/spdlog.h>
 #include "backend.hpp"

 namespace huggingface::tgi::backends::llamacpp {
     struct generation_params_t;
     struct sampling_params_t;

-    class llama_cpp_backend_impl_t;
+    class llama_cpp_worker_frontend_t;
 }

 #include "backend.hpp"
 #include "backends/llamacpp/src/lib.rs.h"
 #include "rust/cxx.h"

 namespace huggingface::tgi::backends::llamacpp {

-    // Concept identifying types which have a .generate() -> size_t method to do in-place generation
-    template<typename T>
-    concept has_stream_method = requires(
-            T t,
-            std::span<const llama_token> input_tokens,
-            const generation_params_t &generation_params,
-            const sampling_params_t &sampling_params,
-            llama_decode_callback callback
-    ) {
-        {
-            t.stream(input_tokens, generation_params, sampling_params, callback)
-        } -> std::same_as<std::expected<size_t, backend_error_t>>;
+    auto llama_model_deleter = [](llama_model *model) { llama_free_model(model); };
+    auto make_shared_llama_model = [](llama_model *model) {
+        return std::shared_ptr<llama_model>(model, llama_model_deleter);
     };

-    static_assert(has_stream_method<single_worker_backend_t>, "single_worker_backend_t doesn't meet concept has_stream_method");
-    static_assert(has_stream_method<multi_worker_backend_t>, "multi_worker_backend_t doesn't meet concept has_stream_method");
-
-    class llama_cpp_backend_exception_t : std::exception {
-
-    };
+    class llama_cpp_backend_exception_t : std::exception {};

     /**
-     * Llama.cpp backend interfacing with Rust FFI layer
+     * Llama.cpp frontend over the worker interfacing with Rust FFI layer
      */
-    class llama_cpp_backend_impl_t {
+    class llama_cpp_worker_frontend_t {
     private:
-        std::variant<single_worker_backend_t, multi_worker_backend_t> mInner_;
+        std::shared_ptr<llama_model> model_;
+        worker_t worker_;

     public:
-        explicit llama_cpp_backend_impl_t(single_worker_backend_t &&backend) : mInner_(std::move(backend)) {}
-
-        explicit llama_cpp_backend_impl_t(multi_worker_backend_t &&backend) : mInner_(std::move(backend)) {}
+        explicit llama_cpp_worker_frontend_t(llama_model *model):
+            model_{ make_shared_llama_model(model) }, worker_(model_, {.no_perf = true}) {}

         size_t stream(
                 rust::Slice<const uint32_t> input_tokens,
@@ -67,41 +50,31 @@ namespace huggingface::tgi::backends::llamacpp {
                 InferContext *ctx,
                 rust::Fn<bool(InferContext *, uint32_t, float_t, bool, size_t)> callback
         ) {
-            // Define the visitor lambda function which requires the has_emplace_generate constraint on T
-            auto inner_fw = [=, &sampling_params, &ctx, &callback]<has_stream_method T>(T &&backend)
-                    -> std::expected<size_t, backend_error_t> {
-
-                auto context_forwarding_callback = [=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens) -> bool {
-                    return callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
-                };
-
-                // Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
-                auto input_tokens_v =
-                        std::span(reinterpret_cast<const llama_token *>(input_tokens.data()), input_tokens.size());
-
-                return backend.stream(
-                        input_tokens_v,
-                        generation_params,
-                        sampling_params,
-                        context_forwarding_callback
-                );
+            auto context_forwarding_callback =
+                    [=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens) -> bool {
+                return callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
             };

-            if (const auto result = std::visit(inner_fw, mInner_); result.has_value()) {
+            // Ask the compiler to create view over Rust slice transmuting from uint32_t* to llama_token*
+            auto input_tokens_v =
+                    std::span(reinterpret_cast<const llama_token *>(input_tokens.data()), input_tokens.size());
+
+            const auto generation_context = generation_context_t {generation_params, sampling_params, input_tokens_v};
+            if(const auto result = worker_.generate(generation_context, context_forwarding_callback); result.has_value()) [[likely]] {
                 return *result;
             } else {
-                throw llama_cpp_backend_exception_t();
+                throw llama_cpp_backend_exception_t {};
             }
         }
     };

-    std::unique_ptr<llama_cpp_backend_impl_t> create_single_worker_backend(rust::Str modelPath) {
+    std::unique_ptr<llama_cpp_worker_frontend_t> create_worker_frontend(rust::Str modelPath) {
         const auto cxxPath = std::string(modelPath);
         auto params = llama_model_default_params();
         params.use_mmap = true;

-        auto *model = llama_load_model_from_file(cxxPath.c_str(), params);
-        return std::make_unique<llama_cpp_backend_impl_t>(single_worker_backend_t { model, std::nullopt });
+        auto *model = (llama_load_model_from_file(cxxPath.c_str(), params));
+        return std::make_unique<llama_cpp_worker_frontend_t>(model);
     }
 }
@@ -1,16 +1,17 @@
 //
 // Created by mfuntowicz on 10/3/24.
 //
 #include <memory>

 #include <fmt/color.h>
 #include <fmt/format.h>
 #include <fmt/std.h>
 #include <fmt/ranges.h>
 #include <llama.h>
 #include <spdlog/spdlog.h>
 #include <spdlog/fmt/ranges.h>s
 #include "../csrc/backend.hpp"

 using namespace huggingface::tgi::backends::llamacpp;

+const auto llama_model_deleter = [](llama_model *model) { llama_free_model(model); };

 int main(int argc, char **argv) {
     if (argc < 2) {
         fmt::print("No model folder provider");
@@ -18,21 +19,31 @@ int main(int argc, char **argv) {
     }

     spdlog::set_level(spdlog::level::debug);

     const auto modelPath = absolute(std::filesystem::path(argv[1]));
     const auto params = llama_model_default_params();
-    auto *model = llama_load_model_from_file(modelPath.c_str(), params);
+    auto model = std::unique_ptr<llama_model, decltype(llama_model_deleter)>(
+            llama_load_model_from_file(modelPath.c_str(), params)
+    );

-    auto backend = single_worker_backend_t(model, {});
+    auto prompt = "My name is Morgan";
+    auto tokens = std::vector<llama_token>(16);
+    const auto nb_tokens = llama_tokenize(model.get(), prompt, sizeof(prompt), tokens.data(), tokens.size(), true,
+                                          false);
+    tokens.resize(nb_tokens);
+    auto backend = worker_t{std::move(model), {.n_batch = 1, .n_threads = 4}};

+    fmt::println("Tokenized: {}", tokens);

     // generate
-    const auto promptTokens = {128000, 5159, 836, 374, 23809, 11};
-    const auto out = backend.generate(promptTokens, {.max_new_tokens = 32}, {.top_k = 40});
-
-    if (out.has_value())
-        fmt::print(FMT_STRING("Generated: {}"), *out);
-    else {
-        const auto err = out.error();
-        fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Got an error: {:d}", static_cast<uint8_t>(err));
-    }
+    auto generated_tokens = std::vector<llama_token>(32);
+    const auto n_generated_tokens = backend.generate(
+            {{.max_new_tokens = 32}, {.top_k = 40}, tokens},
+            [&generated_tokens](llama_token new_token_id, float_t logit, bool is_eos, size_t step) -> bool {
+                generated_tokens.emplace(generated_tokens.begin() + (step - 1), new_token_id);
+                return false;
+            }
+    );
+    generated_tokens.resize(n_generated_tokens.value());
+    fmt::println("Generated {} tokens", generated_tokens);
 }
@@ -1,8 +1,9 @@
 use crate::ffi::{
-    create_single_worker_backend, GenerationParams, LlamaCppBackendImpl, SamplingParams,
+    create_worker_frontend, GenerationParams, LlamaCppWorkerFrontend, SamplingParams,
 };
 use async_trait::async_trait;
 use cxx::UniquePtr;
+use std::ops::Deref;
 use std::path::{Path, PathBuf};
 use std::sync::mpsc::{channel, Receiver, Sender};
 use std::sync::Arc;
@@ -21,7 +22,7 @@ use tracing::{debug, error, info};

 type InferResult = Result<InferStreamResponse, InferError>;

-unsafe impl Send for LlamaCppBackendImpl {}
+unsafe impl Send for LlamaCppWorkerFrontend {}

 impl From<&ValidParameters> for SamplingParams {
     fn from(v: &ValidParameters) -> Self {
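A note on the `unsafe impl Send` swap above: the `UniquePtr<LlamaCppWorkerFrontend>` handle wraps a raw C++ pointer, so it is not `Send` by default and could not be moved into the spawned scheduler thread without this manually audited promise. The following self-contained toy (illustrative only, none of these names exist in TGI or cxx) shows the same situation in miniature:

    use std::thread::spawn;

    // Stand-in for an FFI handle that owns a raw pointer to a C++ object.
    struct ForeignHandle(*mut ());

    // SAFETY (assumed for this sketch): the wrapped object is only ever touched
    // from the thread that receives the handle, mirroring the single-worker design.
    unsafe impl Send for ForeignHandle {}

    fn main() {
        let handle = ForeignHandle(std::ptr::null_mut());
        // Without the `unsafe impl Send` above, this `spawn` does not compile,
        // because the closure captures a type holding a raw pointer.
        let worker = spawn(move || {
            assert!(handle.0.is_null()); // the handle now lives on the worker thread
        });
        worker.join().unwrap();
    }

Deleting the `unsafe impl Send` line turns the `spawn` call into a compile error, which is exactly the guarantee the marker opts out of.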
@@ -68,41 +69,54 @@ pub enum LlamaCppBackendError {
     ModelInitializationFailed(PathBuf, String),
 }

-pub struct LlamaCppBackend {
-    backlog: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
-    _scheduler_handle: JoinHandle<()>,
+// pub struct LlamaCppBackend {
+//     backlog: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
+//     _scheduler_handle: JoinHandle<()>,
+// }
+
+struct LlamaCppWorker {
+    sender: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
+    handle: JoinHandle<()>,
+}
+
+pub enum LlamaCppBackend {
+    Single(LlamaCppWorker),
+    // Multi(Vec<LlamaCppWorker>)
 }

 impl LlamaCppBackend {
-    pub fn new<P: AsRef<Path> + Send>(
+    fn allocate_worker(
+        path: &Path,
+    ) -> Result<UniquePtr<LlamaCppWorkerFrontend>, LlamaCppBackendError> {
+        create_worker_frontend(&path.display().to_string()).map_err(|ref err| {
+            LlamaCppBackendError::ModelInitializationFailed(path.to_path_buf(), err.to_string())
+        })
+    }
+
+    pub fn new<P: AsRef<Path>>(
         model_path: P,
         tokenizer: Tokenizer,
+        num_cores_per_instance: u16,
     ) -> Result<Self, LlamaCppBackendError> {
-        let path = Arc::new(model_path.as_ref());
+        let shared_path = Arc::new(model_path);
+        let path = shared_path.deref().as_ref();
         if !path.exists() {
             return Err(LlamaCppBackendError::ModelFileDoesntExist(
                 path.display().to_string(),
             ));
         }

-        let backend = create_single_worker_backend(path.to_str().unwrap()).map_err(|err| {
-            LlamaCppBackendError::ModelInitializationFailed(
-                path.to_path_buf(),
-                err.what().to_string(),
-            )
-        })?;
+        let worker = match num_cores_per_instance {
+            0 => {
+                let worker = Self::allocate_worker(path)?;
+                let (sender, receiver) = channel();
+                let handle = spawn(|| scheduler_loop(worker, tokenizer, receiver));
+                LlamaCppBackend::Single(LlamaCppWorker { sender, handle })
+            }
+            _ => panic!("No supported yet"),
+        };

         info!(
             "Successfully initialized llama.cpp backend from {}",
             path.display()
         );

-        let (submitter, receiver) = channel();
-        let handle = unsafe { spawn(|| scheduler_loop(backend, tokenizer, receiver)) };
-        Ok(Self {
-            backlog: submitter,
-            _scheduler_handle: handle,
-        })
+        Ok(worker)
     }
 }
@@ -169,18 +183,16 @@ fn llama_generate_callback(
     };

     // Send back to the client
-    let should_stop = if let Err(ref _err) = ctx.stream.send(response) {
+    if let Err(ref _err) = ctx.stream.send(response) {
         error!("Failed to send back the response to the client, cancelling request");
         true
     } else {
         true
     };

-    should_stop
+    false
 }
 }

-unsafe fn scheduler_loop(
-    mut backend: UniquePtr<LlamaCppBackendImpl>,
+fn scheduler_loop(
+    mut backend: UniquePtr<LlamaCppWorkerFrontend>,
     tokenizer: Tokenizer,
     backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
 ) {
@@ -204,20 +216,23 @@ unsafe fn scheduler_loop(
                 generation,
             });

-            let boxed_ctx = Box::into_raw(ctx);
+            // We leak the box to avoid it being freed after the first callback call
+            // when going out of scope
+            unsafe {
+                let boxed_ctx = Box::into_raw(ctx);
+                if let Err(e) = backend.pin_mut().stream(
+                    &input_tokens,
+                    generation_params,
+                    &sampling_params,
+                    boxed_ctx,
+                    llama_generate_callback,
+                ) {
+                    error!("Error while decoding tokens... {}", e.what());
+                }

-            if let Err(e) = backend.pin_mut().stream(
-                &input_tokens,
-                generation_params,
-                &sampling_params,
-                boxed_ctx,
-                llama_generate_callback,
-            ) {
-                error!("Error while decoding tokens... {}", e.what());
+                // Make sure we re-keep track of the OpaqueStream box
+                let _ = Box::from_raw(boxed_ctx);
             }

-            // Make sure we re-keep track of the OpaqueStream box
-            let _ = Box::from_raw(boxed_ctx);
-        }
         } else {
             info!("IPC channel is closed, exiting the scheduler loop");
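The `Box::into_raw` / `Box::from_raw` pair above is the delicate part of this hunk: the context box is leaked on purpose so it outlives every callback invocation, then re-owned exactly once so it still gets dropped. A standalone toy of that round trip (made-up names, no TGI or llama.cpp types involved):

    struct StreamCtx {
        collected: Vec<u32>,
    }

    // Stands in for the C++ side invoking the callback once per generated token.
    fn fake_stream(ctx: *mut StreamCtx, callback: fn(*mut StreamCtx, u32) -> bool) {
        for token in [1_u32, 2, 3] {
            if callback(ctx, token) {
                break;
            }
        }
    }

    fn on_token(ctx: *mut StreamCtx, token: u32) -> bool {
        // Borrow only: calling Box::from_raw here would free the context after
        // the first token, which is what leaking the box avoids.
        let ctx = unsafe { &mut *ctx };
        ctx.collected.push(token);
        false // keep generating
    }

    fn main() {
        // Release ownership so the context stays alive across every callback...
        let raw = Box::into_raw(Box::new(StreamCtx { collected: Vec::new() }));
        fake_stream(raw, on_token);
        // ...then take it back exactly once so it is dropped normally.
        let ctx = unsafe { Box::from_raw(raw) };
        assert_eq!(ctx.collected, vec![1, 2, 3]);
    }

If the callback took ownership back itself, the context would be gone after the first token; that is the situation the new comment in the diff guards against.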
@@ -244,11 +259,13 @@ impl Backend for LlamaCppBackend {
                 sampling_params,
             };

-            match self.backlog.send((ctx, sx)) {
-                Ok(_) => Ok(UnboundedReceiverStream::new(rx)),
-                Err(_) => Err(InferError::GenerationError(
-                    "Failed to sent the request".to_string(),
-                )),
+            match self {
+                LlamaCppBackend::Single(worker) => match worker.sender.send((ctx, sx)) {
+                    Ok(_) => Ok(UnboundedReceiverStream::new(rx)),
+                    Err(_) => Err(InferError::GenerationError(
+                        "Failed to sent the request".to_string(),
+                    )),
+                },
             }
         } else {
             Err(InferError::GenerationError(
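Taken together, the hunks above replace the old backlog/handle struct with a worker held behind a single-variant enum: one spawned thread owns the receiving end of the channel, and `schedule` dispatches on the enum so a multi-worker arm can be added later without touching call sites. A reduced, self-contained sketch of that shape (toy names and payloads, not the TGI types):

    use std::sync::mpsc::{channel, Sender};
    use std::thread::{spawn, JoinHandle};

    struct Request(u32);

    struct Worker {
        sender: Sender<Request>,
        handle: JoinHandle<()>,
    }

    enum Backend {
        Single(Worker),
        // Multi(Vec<Worker>) would slot in here later.
    }

    impl Backend {
        fn new() -> Self {
            let (sender, receiver) = channel::<Request>();
            // Stand-in for scheduler_loop: drain requests until the channel closes.
            let handle = spawn(move || {
                while let Ok(Request(id)) = receiver.recv() {
                    println!("processing request {id}");
                }
            });
            Backend::Single(Worker { sender, handle })
        }

        fn schedule(&self, request: Request) -> Result<(), String> {
            match self {
                Backend::Single(worker) => worker
                    .sender
                    .send(request)
                    .map_err(|_| "failed to send the request".to_string()),
            }
        }
    }

    fn main() {
        let backend = Backend::new();
        backend.schedule(Request(1)).unwrap();
        backend.schedule(Request(2)).unwrap();

        // Dropping the sender closes the channel so the worker exits cleanly.
        match backend {
            Backend::Single(worker) => {
                drop(worker.sender);
                worker.handle.join().unwrap();
            }
        }
    }

The exhaustive `match` on the enum is the same design choice as in the diff: adding a `Multi` variant later forces every dispatch site to handle it explicitly.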
@@ -46,14 +46,13 @@ mod ffi {
         type SamplingParams;

         /// Represent an instance of the llama.cpp backend instance on C++ side
-        #[cxx_name = "llama_cpp_backend_impl_t"]
-        type LlamaCppBackendImpl;
+        #[cxx_name = "llama_cpp_worker_frontend_t"]
+        type LlamaCppWorkerFrontend;

-        #[rust_name = "create_single_worker_backend"]
-        fn create_single_worker_backend(modelPath: &str) -> Result<UniquePtr<LlamaCppBackendImpl>>;
+        fn create_worker_frontend(modelPath: &str) -> Result<UniquePtr<LlamaCppWorkerFrontend>>;

         unsafe fn stream(
-            self: Pin<&mut LlamaCppBackendImpl>,
+            self: Pin<&mut LlamaCppWorkerFrontend>,
             tokens: &[u32],
             generation_params: GenerationParams,
             sampling_params: &SamplingParams,
@@ -37,8 +37,8 @@ struct Args {
     port: u16,
     #[clap(long, env, help = "Path to GGUF model file(s) to load")]
     gguf_path: PathBuf,
-    // #[clap(long, env, default_value = "1", help = "Number of model instance(s)")]
-    // num_model_instance: u16,
+    #[clap(long, env, help = "Number of CPU core per instance(s)")]
+    num_cores_per_instance: Option<u16>,
    #[clap(long, env, required = true)]
     tokenizer_name: String,
     #[clap(long, env)]
@@ -95,7 +95,7 @@ async fn main() -> Result<(), RouterError> {
         hostname,
         port,
         gguf_path,
-        // num_model_instance,
+        num_cores_per_instance,
         tokenizer_name,
         tokenizer_config_path,
         revision,
@@ -164,7 +164,7 @@ async fn main() -> Result<(), RouterError> {
     };
     let tokenizer = tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
         .expect("Failed to retrieve tokenizer");
-    let backend = LlamaCppBackend::new(gguf_path, tokenizer)?;
+    let backend = LlamaCppBackend::new(gguf_path, tokenizer, num_cores_per_instance.unwrap_or(0))?;

     // Run server
     server::run(