feat(backend): simplify overall cpp structure

This commit is contained in:
Morgan Funtowicz 2024-11-09 22:10:33 +01:00
parent 4f5397c414
commit 86d30aea43
7 changed files with 144 additions and 321 deletions

View File

@ -49,43 +49,28 @@ namespace huggingface::tgi::backends::llamacpp {
}
llama_sampler_chain_add(pSampler, llama_sampler_init_dist(seed));
return llama_sampler_ptr(pSampler, llama_sampler_deleter);
return {pSampler, llama_sampler_deleter};
}
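The change above swaps an explicit llama_sampler_ptr(pSampler, llama_sampler_deleter) construction for the shorter brace form. A minimal standalone sketch of the underlying pattern, a unique_ptr with a custom deleter returned via brace-initialization; the sampler, sampler_free and make_sampler names are invented for the example and are not llama.cpp symbols:

#include <memory>

// Placeholder stand-ins for the llama.cpp sampler type and its free function.
struct sampler { /* opaque */ };
void sampler_free(sampler *s) { delete s; }

// Assumed shape of llama_sampler_ptr: a unique_ptr whose deleter is a function pointer.
using sampler_ptr = std::unique_ptr<sampler, decltype(&sampler_free)>;

sampler_ptr make_sampler() {
    auto *raw = new sampler();
    // Brace-initialization and the explicit sampler_ptr(raw, sampler_free) call are
    // equivalent here; the commit simply keeps the shorter form.
    return {raw, sampler_free};
}

int main() {
    const auto s = make_sampler();  // sampler_free runs automatically at scope exit
}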
worker_t::worker_t(std::shared_ptr<llama_model> model, const llama_context_params &params)
: mModel_(model), mParams_(params) {
: model_(model), context_(llama_new_context_with_model(model_.get(), params)) {
#ifdef TGI_LLAMACPP_BACKEND_DEBUG
char modelName[256];
llama_model_meta_val_str(model.get(), "general.name", modelName, sizeof(modelName));
SPDLOG_DEBUG(FMT_STRING("Created llama.cpp backend for model: '{}'"), std::string_view(modelName));
SPDLOG_DEBUG(FMT_STRING("Created llama.cpp backend for model: '{}'"), std::string_view(modelName));
#endif
}
void worker_t::loop(std::stop_source &driver, std::queue<generation_context_t> &backlog) const {
auto *context = llama_new_context_with_model(mModel_.get(), mParams_);
while (!driver.stop_requested()) {
const auto generation_context = backlog.front();
generate(context, generation_context, std::nullopt);
backlog.pop();
SPDLOG_DEBUG("Processed request ({:d} remaining)", backlog.size());
}
llama_free(context);
}
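The loop above keeps serving a backlog of generation requests until the driving stop_source asks it to stop. A self-contained sketch of that drive pattern with std::jthread and std::stop_source, with no llama.cpp dependency; job_t, worker_loop and the timing in main are invented for the demo, and a real implementation would block on a mutex and condition variable instead of spinning:

#include <chrono>
#include <cstdio>
#include <queue>
#include <stop_token>
#include <thread>

// Placeholder for generation_context_t.
struct job_t { int id; };

// Same shape as worker_t::loop: keep serving the backlog until a stop is requested.
void worker_loop(std::stop_source &driver, std::queue<job_t> &backlog) {
    while (!driver.stop_requested()) {
        if (backlog.empty())
            continue;  // spin for brevity; a real loop would wait on a condition variable
        const auto job = backlog.front();
        std::printf("processed job %d (%zu remaining)\n", job.id, backlog.size() - 1);
        backlog.pop();
    }
}

int main() {
    std::queue<job_t> backlog;
    backlog.push({1});
    backlog.push({2});

    std::stop_source driver;
    std::jthread worker([&] { worker_loop(driver, backlog); });

    std::this_thread::sleep_for(std::chrono::milliseconds(100));  // crude: give the worker time
    driver.request_stop();                                        // ~jthread joins on scope exit
}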
size_t worker_t::generate(
llama_context *context,
const generation_context_t &generation_context,
const std::optional<llama_decode_callback> &callback) const {
std::expected<size_t, backend_error_t>
worker_t::generate(const generation_context_t &generation_context,
const std::optional<llama_decode_callback> &callback) const {
// Store information about context and generation size
const auto callback_ = callback.value_or(llama_void_callback);
auto max_new_tokens = generation_context.generation_params.max_new_tokens;
// Convert sampling params to what llama.cpp is looking for
auto sampler = generation_context.sampling_params.into_llama_sampler(mModel_.get());
auto sampler = generation_context.sampling_params.into_llama_sampler(model_.get());
// Set up the prompt
auto copy = std::vector(generation_context.input_tokens.begin(), generation_context.input_tokens.end());
@ -94,11 +79,10 @@ namespace huggingface::tgi::backends::llamacpp {
// Decode
auto n_decoded_tokens = 0;
for (bool generating = true; generating; ++n_decoded_tokens) {
const auto callback_ = callback.value_or(llama_void_callback);
#ifdef TGI_LLAMACPP_BACKEND_DEBUG
const auto start = std::chrono::steady_clock::now();
const auto status = llama_decode(context, batch);
const auto status = llama_decode(context_.get(), batch);
const auto end = std::chrono::steady_clock::now();
const auto latency = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
SPDLOG_DEBUG(FMT_STRING("Successfully decoded {:d} token(s) in {}"), batch.n_tokens, latency);
@ -108,8 +92,8 @@ namespace huggingface::tgi::backends::llamacpp {
batch.n_tokens = 0;
if (LLAMA_SUCCESS(status)) [[likely]] {
// Sample the new token
auto new_token_id = llama_sampler_sample(sampler.get(), context, -1);
auto is_eog = llama_token_is_eog(mModel_.get(), new_token_id);
auto new_token_id = llama_sampler_sample(sampler.get(), context_.get(), -1);
auto is_eog = llama_token_is_eog(model_.get(), new_token_id);
auto new_token_logits = 0.0f; // TODO: return logit
// Handle termination cases
@ -119,11 +103,8 @@ namespace huggingface::tgi::backends::llamacpp {
generating = !(has_reach_max_tokens | has_reach_eog);
// Bubble up the generated token if a callback is provided
const auto should_stop = std::invoke(std::forward<const llama_decode_callback>(callback_),
new_token_id,
new_token_logits,
!generating,
n_decoded_tokens + 1);
const auto should_stop =
std::invoke(callback_, new_token_id, new_token_logits, !generating, n_decoded_tokens + 1);
generating ^= should_stop;
batch = llama_batch_get_one(&new_token_id, 1);
@ -132,62 +113,4 @@ namespace huggingface::tgi::backends::llamacpp {
return n_decoded_tokens;
}
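worker_t::generate now reports its token count through std::expected and streams tokens out through a callback that may stop generation early. A hedged sketch of how a caller could drive such an interface and collect tokens; gen_error_t, fake_generate and the fixed token value are stand-ins for backend_error_t, worker_t::generate and real sampling:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <expected>
#include <functional>
#include <vector>

enum class gen_error_t : uint8_t { decode_failed };

// Assumed callback shape: token id, logit, end-of-stream flag, tokens generated so far.
// Returning true asks the generation loop to stop early.
using decode_callback = std::function<bool(uint32_t, float, bool, size_t)>;

// Stand-in for worker_t::generate: emits a fixed token until max_new_tokens is reached.
std::expected<size_t, gen_error_t>
fake_generate(size_t max_new_tokens, const decode_callback &callback) {
    size_t n_decoded = 0;
    for (bool generating = true; generating;) {
        const uint32_t new_token_id = 42;  // the real code gets this from llama_sampler_sample
        ++n_decoded;
        const bool is_last = n_decoded >= max_new_tokens;
        const bool should_stop = callback(new_token_id, 0.0f, is_last, n_decoded);
        generating = !is_last && !should_stop;
    }
    return n_decoded;
}

int main() {
    std::vector<uint32_t> generated;
    const auto result = fake_generate(8, [&](uint32_t id, float, bool, size_t) {
        generated.push_back(id);
        return false;  // never request an early stop
    });
    if (result.has_value())
        std::printf("generated %zu token(s)\n", *result);
    else
        std::printf("decode failed: %d\n", static_cast<int>(result.error()));
}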
backend_base_t::backend_base_t(llama_model *model) : mModel_(model, llama_free_model) { llama_backend_init(); }
backend_base_t::~backend_base_t() { llama_backend_free(); }
std::expected<std::vector<llama_token>, backend_error_t> backend_base_t::generate(
std::span<const llama_token> tokens,
const generation_params_t &generation_params,
const sampling_params_t &sampling_params,
const std::optional<llama_decode_callback> &callback
) {
// TODO: Should we provide a way to change this value?
auto generated = std::vector<llama_token>(2 << 8);
auto inner_callback = [&](uint32_t new_token_id, float_t new_token_logit, bool is_eos,
size_t num_generated_tokens) -> bool {
generated.emplace_back(new_token_id);
if (callback.has_value())
return (*callback)(new_token_id, new_token_logit, is_eos, num_generated_tokens);
return true;
};
auto nTokensGenerated = stream(tokens, generation_params, sampling_params, inner_callback);
return generated;
}
/** Single worker_t Backend impl **/
single_worker_backend_t::single_worker_backend_t(llama_model *model,
const std::optional<llama_context_params> &params)
: backend_base_t(model),
mContext_(llama_context_factory(model)),
mWorker_(mModel_, params.value_or(llama_context_default_params())) {
llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_NUMACTL);
};
std::expected<size_t, backend_error_t>
single_worker_backend_t::stream(
std::span<const llama_token> tokens,
const generation_params_t &generation_params,
const sampling_params_t &sampling_params,
const llama_decode_callback &callback
) {
return mWorker_.generate(mContext_.get(), {generation_params, sampling_params, tokens}, callback);
}
std::expected<size_t, backend_error_t>
multi_worker_backend_t::stream(
std::span<const llama_token> tokens,
const generation_params_t &generation_params,
const sampling_params_t &sampling_params,
const llama_decode_callback &callback
) {
SPDLOG_WARN("Not implemented for multi_worker_t");
return 0;
}
}

View File

@ -76,8 +76,8 @@ namespace huggingface::tgi::backends::llamacpp {
*/
class worker_t {
private:
const std::shared_ptr<llama_model> mModel_;
const llama_context_params mParams_;
std::shared_ptr<llama_model> model_;
llama_context_ptr context_;
public:
/**
@ -85,7 +85,7 @@ namespace huggingface::tgi::backends::llamacpp {
* @param model
* @param params
*/
worker_t(std::shared_ptr<llama_model> model, const llama_context_params &params);
worker_t(std::shared_ptr<llama_model>, const llama_context_params &);
/**
*
@ -93,108 +93,8 @@ namespace huggingface::tgi::backends::llamacpp {
* @param generation_context
* @param callback
*/
size_t
generate(llama_context *, const generation_context_t &, const std::optional<llama_decode_callback> &) const;
/**
*
*/
void loop(std::stop_source &driver, std::queue<generation_context_t> &backlog) const;
};
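The reworked worker_t keeps the model behind a std::shared_ptr and owns its context through a smart-pointer alias instead of storing raw llama_context_params. A minimal sketch of that ownership split; model_t, context_t and the worker class here are placeholders, not the llama.cpp types:

#include <memory>
#include <utility>

// Opaque stand-ins for llama_model / llama_context and their free functions.
struct model_t {};
struct context_t {};
void model_free(model_t *m) { delete m; }
void context_free(context_t *c) { delete c; }

// Assumed shape of llama_context_ptr: exclusive ownership with a custom deleter.
using context_ptr = std::unique_ptr<context_t, decltype(&context_free)>;

class worker {
    std::shared_ptr<model_t> model_;  // shared: several workers may reference the same weights
    context_ptr context_;             // exclusive: each worker keeps its own context/KV cache
public:
    explicit worker(std::shared_ptr<model_t> model)
        : model_(std::move(model)), context_(new context_t(), context_free) {}
};

int main() {
    const auto model = std::shared_ptr<model_t>(new model_t(), model_free);
    worker w1(model);
    worker w2(model);  // both share the model; the last shared_ptr to drop calls model_free
}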
class backend_base_t {
protected:
std::shared_ptr<llama_model> mModel_;
public:
/**
*
* @param model
*/
explicit backend_base_t(llama_model *model);
/**
* Destructor
*/
~backend_base_t();
/**
*
* @param tokens
* @param generation_params
* @param sampling_params
* @param callback
* @return
*/
[[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
std::expected<std::vector<llama_token>, backend_error_t> generate(
std::span<const llama_token> tokens,
const generation_params_t &generation_params,
const sampling_params_t &sampling_params,
const std::optional<llama_decode_callback> &callback = std::nullopt
);
/**
*
* @param tokens
* @param generation_params
* @param sampling_params
* @params callback
* @return
*/
[[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
virtual std::expected<size_t, backend_error_t> stream(
std::span<const llama_token> tokens,
const generation_params_t &generation_params,
const sampling_params_t &sampling_params,
const llama_decode_callback &callback
) = 0;
};
class single_worker_backend_t : backend_base_t {
private:
constexpr static auto llama_context_factory = [](llama_model *pModel) -> llama_context_ptr {
auto llParams = llama_context_default_params();
llParams.flash_attn = true;
llParams.n_batch = 1;
llParams.n_threads = 1;
llParams.no_perf = true;
llParams.attention_type = llama_attention_type::LLAMA_ATTENTION_TYPE_CAUSAL;
return {llama_new_context_with_model(pModel, llParams), llama_context_deleter};
};
llama_context_ptr mContext_;
worker_t mWorker_;
public:
explicit single_worker_backend_t(llama_model *pModel, const std::optional<llama_context_params> &);
using backend_base_t::generate;
std::expected<size_t, backend_error_t> stream(
std::span<const llama_token> tokens,
const generation_params_t &generation_params,
const sampling_params_t &sampling_params,
const llama_decode_callback &callback) override;
};
class multi_worker_backend_t : backend_base_t {
private:
llama_context_ptr mContext_;
public:
using backend_base_t::generate;
std::expected<size_t, backend_error_t> stream(
std::span<const llama_token> tokens,
const generation_params_t &generation_params,
const sampling_params_t &sampling_params,
const llama_decode_callback &callback) override;
[[nodiscard]] std::expected<size_t, backend_error_t>
generate(const generation_context_t &, const std::optional<llama_decode_callback> &) const;
};
}

View File

@ -7,58 +7,41 @@
#include <exception>
#include <filesystem>
#include <memory>
#include <string_view>
#include <variant>
#include <spdlog/spdlog.h>
#include "backend.hpp"
namespace huggingface::tgi::backends::llamacpp {
struct generation_params_t;
struct sampling_params_t;
class llama_cpp_backend_impl_t;
class llama_cpp_worker_frontend_t;
}
#include "backend.hpp"
#include "backends/llamacpp/src/lib.rs.h"
#include "rust/cxx.h"
namespace huggingface::tgi::backends::llamacpp {
// Concept identifying types which have a .generate() -> size_t method to do in-place generation
template<typename T>
concept has_stream_method = requires(
T t,
std::span<const llama_token> input_tokens,
const generation_params_t &generation_params,
const sampling_params_t &sampling_params,
llama_decode_callback callback
) {
{
t.stream(input_tokens, generation_params, sampling_params, callback)
} -> std::same_as<std::expected<size_t, backend_error_t>>;
auto llama_model_deleter = [](llama_model *model) { llama_free_model(model); };
auto make_shared_llama_model = [](llama_model *model) {
return std::shared_ptr<llama_model>(model, llama_model_deleter);
};
static_assert(has_stream_method<single_worker_backend_t>, "single_worker_backend_t doesn't meet concept has_stream_method");
static_assert(has_stream_method<multi_worker_backend_t>, "multi_worker_backend_t doesn't meet concept has_stream_method");
class llama_cpp_backend_exception_t : std::exception {
};
class llama_cpp_backend_exception_t : std::exception {};
/**
* Llama.cpp backend interfacing with Rust FFI layer
* Llama.cpp frontend over the worker interfacing with Rust FFI layer
*/
class llama_cpp_backend_impl_t {
class llama_cpp_worker_frontend_t {
private:
std::variant<single_worker_backend_t, multi_worker_backend_t> mInner_;
std::shared_ptr<llama_model> model_;
worker_t worker_;
public:
explicit llama_cpp_backend_impl_t(single_worker_backend_t &&backend) : mInner_(std::move(backend)) {}
explicit llama_cpp_backend_impl_t(multi_worker_backend_t &&backend) : mInner_(std::move(backend)) {}
explicit llama_cpp_worker_frontend_t(llama_model *model):
model_{ make_shared_llama_model(model) }, worker_(model_, {.no_perf = true}) {}
size_t stream(
rust::Slice<const uint32_t> input_tokens,
@ -67,41 +50,31 @@ namespace huggingface::tgi::backends::llamacpp {
InferContext *ctx,
rust::Fn<bool(InferContext *, uint32_t, float_t, bool, size_t)> callback
) {
// Define the visitor lambda function which requires the has_emplace_generate constraint on T
auto inner_fw = [=, &sampling_params, &ctx, &callback]<has_stream_method T>(T &&backend)
-> std::expected<size_t, backend_error_t> {
auto context_forwarding_callback = [=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens) -> bool {
return callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
};
// Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
auto input_tokens_v =
std::span(reinterpret_cast<const llama_token *>(input_tokens.data()), input_tokens.size());
return backend.stream(
input_tokens_v,
generation_params,
sampling_params,
context_forwarding_callback
);
auto context_forwarding_callback =
[=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens) -> bool {
return callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
};
if (const auto result = std::visit(inner_fw, mInner_); result.has_value()) {
// Ask the compiler to create view over Rust slice transmuting from uint32_t* to llama_token*
auto input_tokens_v =
std::span(reinterpret_cast<const llama_token *>(input_tokens.data()), input_tokens.size());
const auto generation_context = generation_context_t {generation_params, sampling_params, input_tokens_v};
if(const auto result = worker_.generate(generation_context, context_forwarding_callback); result.has_value()) [[likely]] {
return *result;
} else {
throw llama_cpp_backend_exception_t();
throw llama_cpp_backend_exception_t {};
}
}
};
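stream() above folds the opaque InferContext pointer and the Rust-provided callback into a single C++ callable (context_forwarding_callback) before handing it to the worker. A standalone sketch of that adapter, using a plain function pointer where the real code uses rust::Fn; infer_context, on_token and make_forwarding_callback are invented for the example:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>

// Opaque state owned by the caller (the Rust side in the real code).
struct infer_context { size_t received = 0; };

// C-style callback as it crosses the FFI boundary: opaque pointer plus function pointer.
using ffi_callback = bool (*)(infer_context *, uint32_t, float, bool, size_t);

// Wrap (ctx, callback) into one callable the worker can invoke without knowing
// anything about the FFI layer, mirroring context_forwarding_callback.
std::function<bool(uint32_t, float, bool, size_t)>
make_forwarding_callback(infer_context *ctx, ffi_callback callback) {
    return [ctx, callback](uint32_t token, float logit, bool is_eos, size_t n_generated) {
        return callback(ctx, token, logit, is_eos, n_generated);
    };
}

bool on_token(infer_context *ctx, uint32_t token, float, bool is_eos, size_t n_generated) {
    ctx->received += 1;
    std::printf("token %u (%zu generated, eos=%d)\n", token, n_generated, is_eos);
    return false;  // never request an early stop
}

int main() {
    infer_context ctx;
    const auto callback = make_forwarding_callback(&ctx, on_token);
    callback(42, 0.0f, false, 1);  // the worker would call this from its decode loop
    callback(7, 0.0f, true, 2);
}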
std::unique_ptr<llama_cpp_backend_impl_t> create_single_worker_backend(rust::Str modelPath) {
std::unique_ptr<llama_cpp_worker_frontend_t> create_worker_frontend(rust::Str modelPath) {
const auto cxxPath = std::string(modelPath);
auto params = llama_model_default_params();
params.use_mmap = true;
auto *model = llama_load_model_from_file(cxxPath.c_str(), params);
return std::make_unique<llama_cpp_backend_impl_t>(single_worker_backend_t { model, std::nullopt });
auto *model = (llama_load_model_from_file(cxxPath.c_str(), params));
return std::make_unique<llama_cpp_worker_frontend_t>(model);
}
}
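The same stream() implementation also views the incoming Rust slice of u32 token ids as llama tokens without copying, by reinterpreting the buffer into a std::span. A small sketch of that trick; llama_token is an int32_t in llama.cpp, which is what keeps the signed/unsigned reinterpretation well-defined, and the sample token values are arbitrary:

#include <cstdint>
#include <cstdio>
#include <span>
#include <vector>

using token_t = int32_t;  // stand-in for llama_token

int main() {
    // Pretend this buffer arrived from the Rust side as rust::Slice<const uint32_t>.
    const std::vector<uint32_t> input_tokens = {128000u, 5159u, 836u, 374u};

    // Reinterpret the u32 buffer as a span of signed tokens, no copy involved.
    // Accessing an unsigned integer through its corresponding signed type is allowed.
    const auto tokens_view = std::span(
        reinterpret_cast<const token_t *>(input_tokens.data()), input_tokens.size());

    for (const auto token : tokens_view)
        std::printf("%d\n", token);
}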

View File

@ -1,16 +1,17 @@
//
// Created by mfuntowicz on 10/3/24.
//
#include <cstring>
#include <memory>
#include <fmt/color.h>
#include <fmt/format.h>
#include <fmt/std.h>
#include <fmt/ranges.h>
#include <llama.h>
#include <spdlog/spdlog.h>
#include <spdlog/fmt/ranges.h>
#include "../csrc/backend.hpp"
using namespace huggingface::tgi::backends::llamacpp;
const auto llama_model_deleter = [](llama_model *model) { llama_free_model(model); };
int main(int argc, char **argv) {
if (argc < 2) {
fmt::print("No model folder provided");
@ -18,21 +19,31 @@ int main(int argc, char **argv) {
}
spdlog::set_level(spdlog::level::debug);
const auto modelPath = absolute(std::filesystem::path(argv[1]));
const auto params = llama_model_default_params();
auto *model = llama_load_model_from_file(modelPath.c_str(), params);
auto model = std::unique_ptr<llama_model, decltype(llama_model_deleter)>(
llama_load_model_from_file(modelPath.c_str(), params)
);
auto backend = single_worker_backend_t(model, {});
auto prompt = "My name is Morgan";
auto tokens = std::vector<llama_token>(16);
const auto nb_tokens = llama_tokenize(model.get(), prompt, static_cast<int32_t>(std::strlen(prompt)),
tokens.data(), tokens.size(), true, false);
tokens.resize(nb_tokens);
auto backend = worker_t{std::move(model), {.n_batch = 1, .n_threads = 4}};
fmt::println("Tokenized: {}", tokens);
// generate
const auto promptTokens = {128000, 5159, 836, 374, 23809, 11};
const auto out = backend.generate(promptTokens, {.max_new_tokens = 32}, {.top_k = 40});
if (out.has_value())
fmt::print(FMT_STRING("Generated: {}"), *out);
else {
const auto err = out.error();
fmt::print(fmt::emphasis::bold | fg(fmt::color::red), "Got an error: {:d}", static_cast<uint8_t>(err));
}
auto generated_tokens = std::vector<llama_token>(32);
const auto n_generated_tokens = backend.generate(
{{.max_new_tokens = 32}, {.top_k = 40}, tokens},
[&generated_tokens](llama_token new_token_id, float_t logit, bool is_eos, size_t step) -> bool {
generated_tokens.emplace(generated_tokens.begin() + (step - 1), new_token_id);
return false;
}
);
generated_tokens.resize(n_generated_tokens.value());
fmt::println("Generated {} tokens", generated_tokens);
}

View File

@ -1,8 +1,9 @@
use crate::ffi::{
create_single_worker_backend, GenerationParams, LlamaCppBackendImpl, SamplingParams,
create_worker_frontend, GenerationParams, LlamaCppWorkerFrontend, SamplingParams,
};
use async_trait::async_trait;
use cxx::UniquePtr;
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::Arc;
@ -21,7 +22,7 @@ use tracing::{debug, error, info};
type InferResult = Result<InferStreamResponse, InferError>;
unsafe impl Send for LlamaCppBackendImpl {}
unsafe impl Send for LlamaCppWorkerFrontend {}
impl From<&ValidParameters> for SamplingParams {
fn from(v: &ValidParameters) -> Self {
@ -68,41 +69,54 @@ pub enum LlamaCppBackendError {
ModelInitializationFailed(PathBuf, String),
}
pub struct LlamaCppBackend {
backlog: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
_scheduler_handle: JoinHandle<()>,
// pub struct LlamaCppBackend {
// backlog: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
// _scheduler_handle: JoinHandle<()>,
// }
struct LlamaCppWorker {
sender: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
handle: JoinHandle<()>,
}
pub enum LlamaCppBackend {
Single(LlamaCppWorker),
// Multi(Vec<LlamaCppWorker>)
}
impl LlamaCppBackend {
pub fn new<P: AsRef<Path> + Send>(
fn allocate_worker(
path: &Path,
) -> Result<UniquePtr<LlamaCppWorkerFrontend>, LlamaCppBackendError> {
create_worker_frontend(&path.display().to_string()).map_err(|ref err| {
LlamaCppBackendError::ModelInitializationFailed(path.to_path_buf(), err.to_string())
})
}
pub fn new<P: AsRef<Path>>(
model_path: P,
tokenizer: Tokenizer,
num_cores_per_instance: u16,
) -> Result<Self, LlamaCppBackendError> {
let path = Arc::new(model_path.as_ref());
let shared_path = Arc::new(model_path);
let path = shared_path.deref().as_ref();
if !path.exists() {
return Err(LlamaCppBackendError::ModelFileDoesntExist(
path.display().to_string(),
));
}
let backend = create_single_worker_backend(path.to_str().unwrap()).map_err(|err| {
LlamaCppBackendError::ModelInitializationFailed(
path.to_path_buf(),
err.what().to_string(),
)
})?;
let worker = match num_cores_per_instance {
0 => {
let worker = Self::allocate_worker(path)?;
let (sender, receiver) = channel();
let handle = spawn(|| scheduler_loop(worker, tokenizer, receiver));
LlamaCppBackend::Single(LlamaCppWorker { sender, handle })
}
_ => panic!("Not supported yet"),
};
info!(
"Successfully initialized llama.cpp backend from {}",
path.display()
);
let (submitter, receiver) = channel();
let handle = unsafe { spawn(|| scheduler_loop(backend, tokenizer, receiver)) };
Ok(Self {
backlog: submitter,
_scheduler_handle: handle,
})
Ok(worker)
}
}
@ -169,18 +183,16 @@ fn llama_generate_callback(
};
// Send back to the client
let should_stop = if let Err(ref _err) = ctx.stream.send(response) {
if let Err(ref _err) = ctx.stream.send(response) {
error!("Failed to send back the response to the client, cancelling request");
true
} else {
true
};
should_stop
false
}
}
unsafe fn scheduler_loop(
mut backend: UniquePtr<LlamaCppBackendImpl>,
fn scheduler_loop(
mut backend: UniquePtr<LlamaCppWorkerFrontend>,
tokenizer: Tokenizer,
backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
) {
@ -204,20 +216,23 @@ unsafe fn scheduler_loop(
generation,
});
let boxed_ctx = Box::into_raw(ctx);
// We leak the box to avoid it being freed after the first callback call
// when going out of scope
unsafe {
let boxed_ctx = Box::into_raw(ctx);
if let Err(e) = backend.pin_mut().stream(
&input_tokens,
generation_params,
&sampling_params,
boxed_ctx,
llama_generate_callback,
) {
error!("Error while decoding tokens... {}", e.what());
}
if let Err(e) = backend.pin_mut().stream(
&input_tokens,
generation_params,
&sampling_params,
boxed_ctx,
llama_generate_callback,
) {
error!("Error while decoding tokens... {}", e.what());
// Make sure we reclaim ownership of the OpaqueStream box
let _ = Box::from_raw(boxed_ctx);
}
// Make sure we reclaim ownership of the OpaqueStream box
let _ = Box::from_raw(boxed_ctx);
}
} else {
info!("IPC channel is closed, exiting the scheduler loop");
@ -244,11 +259,13 @@ impl Backend for LlamaCppBackend {
sampling_params,
};
match self.backlog.send((ctx, sx)) {
Ok(_) => Ok(UnboundedReceiverStream::new(rx)),
Err(_) => Err(InferError::GenerationError(
"Failed to sent the request".to_string(),
)),
match self {
LlamaCppBackend::Single(worker) => match worker.sender.send((ctx, sx)) {
Ok(_) => Ok(UnboundedReceiverStream::new(rx)),
Err(_) => Err(InferError::GenerationError(
"Failed to sent the request".to_string(),
)),
},
}
} else {
Err(InferError::GenerationError(

View File

@ -46,14 +46,13 @@ mod ffi {
type SamplingParams;
/// Represent an instance of the llama.cpp backend instance on C++ side
#[cxx_name = "llama_cpp_backend_impl_t"]
type LlamaCppBackendImpl;
#[cxx_name = "llama_cpp_worker_frontend_t"]
type LlamaCppWorkerFrontend;
#[rust_name = "create_single_worker_backend"]
fn create_single_worker_backend(modelPath: &str) -> Result<UniquePtr<LlamaCppBackendImpl>>;
fn create_worker_frontend(modelPath: &str) -> Result<UniquePtr<LlamaCppWorkerFrontend>>;
unsafe fn stream(
self: Pin<&mut LlamaCppBackendImpl>,
self: Pin<&mut LlamaCppWorkerFrontend>,
tokens: &[u32],
generation_params: GenerationParams,
sampling_params: &SamplingParams,

View File

@ -37,8 +37,8 @@ struct Args {
port: u16,
#[clap(long, env, help = "Path to GGUF model file(s) to load")]
gguf_path: PathBuf,
// #[clap(long, env, default_value = "1", help = "Number of model instance(s)")]
// num_model_instance: u16,
#[clap(long, env, help = "Number of CPU core per instance(s)")]
num_cores_per_instance: Option<u16>,
#[clap(long, env, required = true)]
tokenizer_name: String,
#[clap(long, env)]
@ -95,7 +95,7 @@ async fn main() -> Result<(), RouterError> {
hostname,
port,
gguf_path,
// num_model_instance,
num_cores_per_instance,
tokenizer_name,
tokenizer_config_path,
revision,
@ -164,7 +164,7 @@ async fn main() -> Result<(), RouterError> {
};
let tokenizer = tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
.expect("Failed to retrieve tokenizer");
let backend = LlamaCppBackend::new(gguf_path, tokenizer)?;
let backend = LlamaCppBackend::new(gguf_path, tokenizer, num_cores_per_instance.unwrap_or(0))?;
// Run server
server::run(