feat(llamacpp): expose number of threads for the backend when constructing the model
commit a316c53255
parent 179309b364
@@ -17,12 +17,16 @@
 namespace huggingface::tgi::backends::llamacpp {
     [[nodiscard]]
     std::expected<std::pair<llama_model *, llama_context *>, TgiLlamaCppBackendError>
-    TgiLlamaCppBackend::FromGGUF(const std::filesystem::path &modelPath) noexcept {
+    TgiLlamaCppBackend::FromGGUF(const std::filesystem::path &modelPath, const uint16_t nThreads) noexcept {
         SPDLOG_DEBUG(FMT_STRING("Loading model from {}"), modelPath);

         llama_backend_init();
         llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_NUMACTL);

+#ifdef TGI_LLAMACPP_BACKEND_DEBUG
+        llama_print_system_info();
+#endif
+
         // Load the model
         if (!exists(modelPath)) {
             return std::unexpected(TgiLlamaCppBackendError::MODEL_FILE_DOESNT_EXIST);
@@ -32,7 +36,7 @@ namespace huggingface::tgi::backends::llamacpp {
         auto *model = llama_load_model_from_file(modelPath.c_str(), params);
         auto *context = llama_new_context_with_model(model, {
                 .n_batch = 1,
-                .n_threads = 16,
+                .n_threads = nThreads,
                 .attention_type = llama_attention_type::LLAMA_ATTENTION_TYPE_CAUSAL,
                 .flash_attn = false,
         });
@@ -43,7 +47,7 @@ namespace huggingface::tgi::backends::llamacpp {
     huggingface::tgi::backends::llamacpp::TgiLlamaCppBackend::TgiLlamaCppBackend(llama_model *const model,
                                                                                   llama_context *const ctx)
             : model(model), ctx(ctx) {
-#ifndef NDEBUG
+#ifdef TGI_LLAMACPP_BACKEND_DEBUG
         char modelName[256];
         llama_model_meta_val_str(llama_get_model(ctx), "general.name", modelName, sizeof(modelName));
         SPDLOG_DEBUG(FMT_STRING("Created llama.cpp backend for model: '{}'"), std::string_view(modelName));
@@ -126,7 +130,7 @@ namespace huggingface::tgi::backends::llamacpp {

         // Decode
         for (auto [generating, nDecoded] = std::pair{true, 0uz}; generating && nDecoded < maxNewTokens; ++nDecoded) {
-#ifndef NDEBUG
+#ifdef TGI_LLAMACPP_BACKEND_DEBUG
             const auto start = std::chrono::steady_clock::now();
             const auto status = llama_decode(ctx, batch);
             const auto end = std::chrono::steady_clock::now();

@@ -42,7 +42,7 @@ namespace huggingface::tgi::backends::llamacpp {
          * @return
          */
         static std::expected<std::pair<llama_model *, llama_context *>, TgiLlamaCppBackendError>
-        FromGGUF(const std::filesystem::path &) noexcept;
+        FromGGUF(const std::filesystem::path &, uint16_t) noexcept;

         TgiLlamaCppBackend(llama_model *model, llama_context *ctx);

@@ -34,9 +34,9 @@ namespace huggingface::tgi::backends::llamacpp::impl {
         LlamaCppBackendImpl(llama_model *model, llama_context *context) : _inner(model, context) {}
     };

-    std::unique_ptr<LlamaCppBackendImpl> CreateLlamaCppBackendImpl(rust::Str modelPath) {
+    std::unique_ptr<LlamaCppBackendImpl> CreateLlamaCppBackendImpl(rust::Str modelPath, uint16_t nThreads) {
         const auto cxxPath = std::string_view(modelPath);
-        if (auto maybe = TgiLlamaCppBackend::FromGGUF(std::filesystem::path(cxxPath)); maybe.has_value()) {
+        if (auto maybe = TgiLlamaCppBackend::FromGGUF(std::filesystem::path(cxxPath), nThreads); maybe.has_value()) {
             auto [model, context] = *maybe;
             return std::make_unique<LlamaCppBackendImpl>(model, context);
         } else {

@@ -24,7 +24,10 @@ pub enum LlamaCppBackendError {
 pub struct LlamaCppBackend {}

 impl LlamaCppBackend {
-    pub fn new<P: AsRef<Path> + Send>(model_path: P) -> Result<Self, LlamaCppBackendError> {
+    pub fn new<P: AsRef<Path> + Send>(
+        model_path: P,
+        n_threads: u16,
+    ) -> Result<Self, LlamaCppBackendError> {
         let path = Arc::new(model_path.as_ref());
         if !path.exists() {
             return Err(LlamaCppBackendError::ModelFileDoesntExist(
@@ -32,12 +35,13 @@ impl LlamaCppBackend {
             ));
         }

-        let mut backend = create_llamacpp_backend(path.to_str().unwrap()).map_err(|err| {
-            LlamaCppBackendError::ModelInitializationFailed(
-                path.to_path_buf(),
-                err.what().to_string(),
-            )
-        })?;
+        let mut backend =
+            create_llamacpp_backend(path.to_str().unwrap(), n_threads).map_err(|err| {
+                LlamaCppBackendError::ModelInitializationFailed(
+                    path.to_path_buf(),
+                    err.what().to_string(),
+                )
+            })?;

         info!(
             "Successfully initialized llama.cpp backend from {}",

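For context, a minimal sketch of what a downstream caller of the new signature looks like. The import path here is an assumption made purely for illustration (it is not the repository's real module layout); the second argument is the number of CPU threads the llama.cpp context will be created with.

```rust
// Assumed crate/module path, for illustration only:
use text_generation_router_llamacpp::backend::{LlamaCppBackend, LlamaCppBackendError};
use std::path::PathBuf;

fn build_backend(
    gguf_path: PathBuf,
    cores_per_instance: u16,
) -> Result<LlamaCppBackend, LlamaCppBackendError> {
    // The thread count is forwarded through the cxx bridge down to the C++
    // backend, which hands it to llama.cpp when creating the context.
    LlamaCppBackend::new(gguf_path, cores_per_instance)
}
```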
@@ -11,6 +11,7 @@ mod ffi {
         #[rust_name = "create_llamacpp_backend"]
         fn CreateLlamaCppBackendImpl(
             modelPath: &str,
+            n_threads: u16,
         ) -> Result<UniquePtr<LlamaCppBackendImpl>>;
     }
 }

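The `Result<UniquePtr<...>>` return type means cxx surfaces C++ exceptions as `cxx::Exception`, which is why the Rust side above reports `err.what()`. As a hedged sketch, the full bridge module has roughly the shape below; the namespace and include path are assumptions, not necessarily what the repository uses. cxx maps `u16` to `uint16_t` and `&str` to `rust::Str`, matching the `uint16_t nThreads` and `rust::Str modelPath` parameters of the C++ shim.

```rust
#[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp::impl")] // assumed namespace
mod ffi {
    unsafe extern "C++" {
        include!("backends/llamacpp/csrc/ffi.hpp"); // assumed header path

        type LlamaCppBackendImpl;

        // The Result lets the C++ side report failures by throwing; the
        // exception text becomes cxx::Exception::what() on the Rust side.
        #[rust_name = "create_llamacpp_backend"]
        fn CreateLlamaCppBackendImpl(
            modelPath: &str,
            n_threads: u16,
        ) -> Result<UniquePtr<LlamaCppBackendImpl>>;
    }
}
```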
@@ -23,24 +23,25 @@ struct Args {
     max_input_tokens: usize,
     #[clap(default_value = "2048", long, env)]
     max_total_tokens: usize,
     #[clap(default_value = "1.2", long, env)]
     waiting_served_ratio: f32,
     #[clap(default_value = "4096", long, env)]
     max_batch_prefill_tokens: u32,
     #[clap(long, env)]
     max_batch_total_tokens: Option<u32>,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
     #[clap(long, env)]
     max_batch_size: Option<usize>,
     #[clap(default_value = "0.0.0.0", long, env)]
     hostname: String,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
     #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
     master_shard_uds_path: String,
     #[clap(long, env, help = "Path to GGUF model file(s) to load")]
     gguf_path: PathBuf,
+    #[clap(
+        long,
+        env,
+        default_value = "1",
+        help = "Number of CPU threads allocated to one llama.cpp model"
+    )]
+    cores_per_instance: u16,
     #[clap(default_value = "bigscience/bloom", long, env)]
     tokenizer_name: String,
     #[clap(long, env)]
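With clap's derive API, the new field becomes a `--cores-per-instance` flag, `env` additionally lets it be set through the `CORES_PER_INSTANCE` environment variable (this relies on clap's `env` feature), and `default_value = "1"` applies when neither is provided. A self-contained sketch of just that behaviour, separate from the repository's `Args` struct:

```rust
use clap::Parser;

#[derive(Parser, Debug)]
struct DemoArgs {
    /// Number of CPU threads allocated to one llama.cpp model
    #[clap(long, env, default_value = "1")]
    cores_per_instance: u16,
}

fn main() {
    // An explicit flag overrides the default.
    let args = DemoArgs::parse_from(["demo", "--cores-per-instance", "8"]);
    assert_eq!(args.cores_per_instance, 8);

    // With neither the flag nor CORES_PER_INSTANCE set, the default of 1 applies.
    let args = DemoArgs::parse_from(["demo"]);
    assert_eq!(args.cores_per_instance, 1);
}
```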
@@ -93,15 +94,13 @@ async fn main() -> Result<(), RouterError> {
         max_top_n_tokens,
         max_input_tokens,
         max_total_tokens,
         waiting_served_ratio,
         max_batch_prefill_tokens,
         max_batch_total_tokens,
         max_waiting_tokens,
         max_batch_size,
         hostname,
         port,
         master_shard_uds_path,
         gguf_path,
+        cores_per_instance,
         tokenizer_name,
         tokenizer_config_path,
         revision,
@@ -162,7 +161,7 @@ async fn main() -> Result<(), RouterError> {
         }
     }

-    let backend = LlamaCppBackend::new(gguf_path)?;
+    let backend = LlamaCppBackend::new(gguf_path, cores_per_instance)?;

     // Run server
     server::run(