[Backend] Add Llamacpp backend (#2975)

* Add llamacpp backend

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Get rid of llama_batch_get_one()

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Use max_batch_total_tokens

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Handle max_batch_size

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add some input validation checks

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Handle ctx args & fix sampling

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add GPU args

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add --defrag-threshold

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add a stupid batch mechanism

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Cleanup

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add --numa

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix args

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Enable flash attention by default

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add --offload-kqv

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix batch_pos

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* backend(llama): add CUDA Dockerfile_llamacpp for now

* Only export the latest logits

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Output real logprobs

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix batching

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix seq iterations

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Auto-detect n_threads when not provided

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Clear request cache after completion

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Remove warmup

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Cleanup

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* backend(llama): add CUDA architectures build argument for Dockerfile

* Add specific args for batch

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add --type-v & --type-k

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Bump llamacpp to b4623

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Disable graceful shutdown in debug mode

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update Dockerfile_llamacpp

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Cleanup Dockerfile

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update Cargo.lock

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update args

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Simplify batching logic

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Set TGI_LLAMA_PKG_CUDA from CUDA_VERSION

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Rename bindings

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Remove n_ctx

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Make max_batch_total_tokens optional

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Ensure all samplers are freed on error

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Initialize penalty_last_n with llamacpp default value

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Cleanup

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Improve default settings

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add doc

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update docs

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Thanks clippy

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Thanks cargo fmt

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update docs

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Do not use HOSTNAME env

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Bump llama.cpp & cuda

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix requirements.txt

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix fmt

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Enable KQV offload by default

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Remove Ngrok tunneling

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Remove .cargo/config.toml

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix Dockerfile

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add missing cuda prefix

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Handle custom llama.cpp dir

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Cleanup

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add README.md

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Add HF transfer

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Fix bool args

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update doc

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Update doc

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
Author: Adrien Gallouët, 2025-02-14 13:40:57 +01:00 (committed by GitHub)
Commit: cfd4fbb479 (parent: 6df0fc0b55)
12 changed files with 1765 additions and 422 deletions

Cargo.lock (generated, 927 lines changed): diff suppressed because it is too large.

Cargo.toml

@@ -5,6 +5,7 @@ members = [
"backends/v3",
"backends/grpc-metadata",
"backends/trtllm",
"backends/llamacpp",
"launcher",
"router"
]

Dockerfile_llamacpp (new file, 76 lines)

@@ -0,0 +1,76 @@
FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04 AS deps
ARG llamacpp_version=b4651
ARG llamacpp_cuda=OFF
ARG cuda_arch=75-real;80-real;86-real;89-real;90-real
WORKDIR /opt/src
ENV DEBIAN_FRONTEND=noninteractive
RUN apt update && apt install -y \
clang \
cmake \
curl \
git \
python3-dev \
libssl-dev \
pkg-config \
tar
ADD https://github.com/ggerganov/llama.cpp/archive/refs/tags/${llamacpp_version}.tar.gz /opt/src/
RUN tar -xzf ${llamacpp_version}.tar.gz \
&& cd llama.cpp-${llamacpp_version} \
&& cmake -B build \
-DCMAKE_INSTALL_PREFIX=/usr \
-DCMAKE_INSTALL_LIBDIR=/usr/lib \
-DCMAKE_C_COMPILER=clang \
-DCMAKE_CXX_COMPILER=clang++ \
-DCMAKE_CUDA_ARCHITECTURES=${cuda_arch} \
-DGGML_CUDA=${llamacpp_cuda} \
-DLLAMA_BUILD_COMMON=OFF \
-DLLAMA_BUILD_TESTS=OFF \
-DLLAMA_BUILD_EXAMPLES=OFF \
-DLLAMA_BUILD_SERVER=OFF \
&& cmake --build build --parallel --config Release \
&& cmake --install build
WORKDIR /app
COPY rust-toolchain.toml rust-toolchain.toml
RUN curl -sSf https://sh.rustup.rs | sh -s -- -y --no-modify-path --default-toolchain none
ENV PATH="/root/.cargo/bin:$PATH"
RUN cargo install cargo-chef --locked
FROM deps AS planner
COPY . .
RUN cargo chef prepare --recipe-path recipe.json
FROM deps AS builder
COPY --from=planner /app/recipe.json recipe.json
RUN cargo chef cook \
--recipe-path recipe.json \
--profile release-opt \
--package text-generation-router-llamacpp
COPY . .
RUN cargo build \
--profile release-opt \
--package text-generation-router-llamacpp --frozen
FROM nvidia/cuda:12.8.0-cudnn-runtime-ubuntu24.04
RUN apt update && apt install -y \
python3-venv \
python3-pip
RUN python3 -m venv /venv
ENV PATH="/venv/bin:$PATH"
COPY backends/llamacpp/requirements.txt requirements.txt
RUN pip3 install --no-cache-dir -r requirements.txt
COPY --from=builder /usr/lib/libllama.so /usr/lib/
COPY --from=builder /usr/lib/libggml*.so /usr/lib/
COPY --from=builder /app/target/release-opt/text-generation-router-llamacpp /usr/bin/
ENV HF_HUB_ENABLE_HF_TRANSFER=1
ENTRYPOINT ["text-generation-router-llamacpp"]

backends/llamacpp/Cargo.toml (new file)

@@ -0,0 +1,21 @@
[package]
name = "text-generation-router-llamacpp"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
[build-dependencies]
bindgen = "0.71.1"
pkg-config = "0.3.31"
[dependencies]
async-trait = "0.1.85"
clap = "4.5.27"
num_cpus = "1.16.0"
text-generation-router = { path = "../../router" }
thiserror = "2.0.11"
tokenizers.workspace = true
tokio = "1.43.0"
tokio-stream = "0.1.17"
tracing = "0.1.41"

backends/llamacpp/README.md (new file)

@@ -0,0 +1,24 @@
# Llamacpp backend

If all your dependencies are installed at the system level, running
`cargo build` should be sufficient. However, if you want to experiment
with different versions of llama.cpp, some additional setup is required.

## Install llama.cpp

```bash
LLAMACPP_PREFIX=$(pwd)/llama.cpp.out

git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
cmake -B build \
    -DCMAKE_INSTALL_PREFIX="$LLAMACPP_PREFIX" \
    -DLLAMA_BUILD_COMMON=OFF \
    -DLLAMA_BUILD_TESTS=OFF \
    -DLLAMA_BUILD_EXAMPLES=OFF \
    -DLLAMA_BUILD_SERVER=OFF
cmake --build build --config Release -j
cmake --install build
```

## Build TGI

```bash
PKG_CONFIG_PATH="$LLAMACPP_PREFIX/lib/pkgconfig" cargo build
```
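
Once built, the router can be started directly from the target directory. Below is a minimal sketch of a local run, assuming a GGUF file has already been downloaded to `~/models/` (the model ID and file name are illustrative placeholders); the rpath emitted by `build.rs` should let the binary find the custom `libllama.so` without extra environment variables.

```bash
# Minimal local run (model ID and GGUF path are placeholders)
target/debug/text-generation-router-llamacpp \
    --port 3000 \
    --model-id "Qwen/Qwen2.5-3B-Instruct" \
    --model-gguf ~/models/qwen2.5-3b-instruct-q4_0.gguf
```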

backends/llamacpp/build.rs (new file)

@@ -0,0 +1,48 @@
use bindgen::callbacks::{ItemInfo, ParseCallbacks};
use std::env;
use std::path::PathBuf;
#[derive(Debug)]
struct PrefixStripper;
impl ParseCallbacks for PrefixStripper {
fn generated_name_override(&self, item_info: ItemInfo<'_>) -> Option<String> {
item_info.name.strip_prefix("llama_").map(str::to_string)
}
}
fn main() {
if let Some(cuda_version) = option_env!("CUDA_VERSION") {
let mut version: Vec<&str> = cuda_version.split('.').collect();
if version.len() > 2 {
version.pop();
}
let cuda_version = format!("cuda-{}", version.join("."));
pkg_config::Config::new().probe(&cuda_version).unwrap();
}
let llama = pkg_config::Config::new().probe("llama").unwrap();
for path in &llama.link_paths {
println!("cargo:rustc-link-arg=-Wl,-rpath,{}", path.display());
}
println!("cargo:rustc-link-arg=-Wl,--disable-new-dtags");
let bindings = bindgen::Builder::default()
.clang_args(
llama
.include_paths
.iter()
.map(|p| format!("-I{}", p.display())),
)
.header_contents("llama_bindings.h", "#include <llama.h>")
.prepend_enum_name(false)
.parse_callbacks(Box::new(PrefixStripper))
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.generate()
.expect("Unable to generate bindings");
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
bindings
.write_to_file(out_path.join("llamacpp.rs"))
.expect("Couldn't write bindings!");
}

backends/llamacpp/requirements.txt (new file)

@@ -0,0 +1,3 @@
transformers==4.48.2
huggingface-hub==0.28.1
hf-transfer==0.1.9

backends/llamacpp/src/backend.rs (new file)

@@ -0,0 +1,679 @@
mod llamacpp {
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
#![allow(dead_code)]
include!(concat!(env!("OUT_DIR"), "/llamacpp.rs"));
}
use async_trait::async_trait;
use std::ffi::CString;
use std::mem::replace;
use std::str::FromStr;
use std::sync::{mpsc, Once};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{FinishReason, Token};
use thiserror::Error;
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::sync::{oneshot, watch};
use tokio::task::{spawn, spawn_blocking};
use tokio::time::{timeout, Duration, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::instrument;
use tracing::{debug, error, info, trace, warn};
#[derive(Debug, Clone, Copy)]
pub enum LlamacppSplitMode {
GPU(usize),
Layer,
Row,
}
impl FromStr for LlamacppSplitMode {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"layer" => Ok(LlamacppSplitMode::Layer),
"row" => Ok(LlamacppSplitMode::Row),
_ => match s.parse::<usize>() {
Ok(n) => Ok(LlamacppSplitMode::GPU(n)),
Err(_) => Err("Choose a GPU number or `layer` or `row`".to_string()),
},
}
}
}
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
pub enum LlamacppNuma {
Disabled,
Distribute,
Isolate,
Numactl,
Mirror,
}
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
pub enum LlamacppGGMLType {
F32,
F16,
Q4_0,
Q4_1,
Q5_0,
Q5_1,
Q8_0,
Q8_1,
Q2_K,
Q3_K,
Q4_K,
Q5_K,
Q6_K,
Q8_K,
IQ2_XXS,
IQ2_XS,
IQ3_XXS,
IQ1_S,
IQ4_NL,
IQ3_S,
IQ2_S,
IQ4_XS,
I8,
I16,
I32,
I64,
F64,
IQ1_M,
BF16,
TQ1_0,
TQ2_0,
}
// TODO: macro
impl LlamacppGGMLType {
fn to_ggml_type(self) -> llamacpp::ggml_type {
match self {
LlamacppGGMLType::F32 => llamacpp::GGML_TYPE_F32,
LlamacppGGMLType::F16 => llamacpp::GGML_TYPE_F16,
LlamacppGGMLType::Q4_0 => llamacpp::GGML_TYPE_Q4_0,
LlamacppGGMLType::Q4_1 => llamacpp::GGML_TYPE_Q4_1,
LlamacppGGMLType::Q5_0 => llamacpp::GGML_TYPE_Q5_0,
LlamacppGGMLType::Q5_1 => llamacpp::GGML_TYPE_Q5_1,
LlamacppGGMLType::Q8_0 => llamacpp::GGML_TYPE_Q8_0,
LlamacppGGMLType::Q8_1 => llamacpp::GGML_TYPE_Q8_1,
LlamacppGGMLType::Q2_K => llamacpp::GGML_TYPE_Q2_K,
LlamacppGGMLType::Q3_K => llamacpp::GGML_TYPE_Q3_K,
LlamacppGGMLType::Q4_K => llamacpp::GGML_TYPE_Q4_K,
LlamacppGGMLType::Q5_K => llamacpp::GGML_TYPE_Q5_K,
LlamacppGGMLType::Q6_K => llamacpp::GGML_TYPE_Q6_K,
LlamacppGGMLType::Q8_K => llamacpp::GGML_TYPE_Q8_K,
LlamacppGGMLType::IQ2_XXS => llamacpp::GGML_TYPE_IQ2_XXS,
LlamacppGGMLType::IQ2_XS => llamacpp::GGML_TYPE_IQ2_XS,
LlamacppGGMLType::IQ3_XXS => llamacpp::GGML_TYPE_IQ3_XXS,
LlamacppGGMLType::IQ1_S => llamacpp::GGML_TYPE_IQ1_S,
LlamacppGGMLType::IQ4_NL => llamacpp::GGML_TYPE_IQ4_NL,
LlamacppGGMLType::IQ3_S => llamacpp::GGML_TYPE_IQ3_S,
LlamacppGGMLType::IQ2_S => llamacpp::GGML_TYPE_IQ2_S,
LlamacppGGMLType::IQ4_XS => llamacpp::GGML_TYPE_IQ4_XS,
LlamacppGGMLType::I8 => llamacpp::GGML_TYPE_I8,
LlamacppGGMLType::I16 => llamacpp::GGML_TYPE_I16,
LlamacppGGMLType::I32 => llamacpp::GGML_TYPE_I32,
LlamacppGGMLType::I64 => llamacpp::GGML_TYPE_I64,
LlamacppGGMLType::F64 => llamacpp::GGML_TYPE_F64,
LlamacppGGMLType::IQ1_M => llamacpp::GGML_TYPE_IQ1_M,
LlamacppGGMLType::BF16 => llamacpp::GGML_TYPE_BF16,
LlamacppGGMLType::TQ1_0 => llamacpp::GGML_TYPE_TQ1_0,
LlamacppGGMLType::TQ2_0 => llamacpp::GGML_TYPE_TQ2_0,
}
}
}
pub struct LlamacppConfig {
pub model_gguf: String,
pub max_batch_total_tokens: usize,
pub max_physical_batch_total_tokens: usize,
pub max_batch_size: usize,
pub batch_timeout: Duration,
pub n_threads: usize,
pub n_threads_batch: usize,
pub n_gpu_layers: usize,
pub split_mode: LlamacppSplitMode,
pub numa: LlamacppNuma,
pub defrag_threshold: f32,
pub use_mmap: bool,
pub use_mlock: bool,
pub offload_kqv: bool,
pub flash_attention: bool,
pub type_k: LlamacppGGMLType,
pub type_v: LlamacppGGMLType,
}
#[derive(Debug)]
struct LlamacppRequest {
input_ids: Vec<i32>,
top_k: i32,
top_p: f32,
typical_p: f32,
min_keep: usize,
temp: f32,
seed: u32,
penalty_last_n: i32,
penalty_repeat: f32,
penalty_freq: f32,
penalty_present: f32,
max_new_tokens: usize,
tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
time: Instant,
}
pub struct LlamacppBackend {
tx: UnboundedSender<LlamacppRequest>,
status: watch::Receiver<bool>,
}
impl LlamacppRequest {
fn new(
from: &ValidGenerateRequest,
tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
) -> Option<Self> {
from.input_ids.as_ref().map(|input_ids| LlamacppRequest {
input_ids: input_ids.iter().map(|&x| x as i32).collect(),
top_k: from.parameters.top_k as _,
top_p: from.parameters.top_p as _,
typical_p: from.parameters.typical_p as _,
min_keep: 0, // disabled
temp: from.parameters.temperature as _,
seed: from.parameters.seed as _,
penalty_last_n: 64, // 0 = disabled, -1 = context size
penalty_repeat: from.parameters.repetition_penalty as _,
penalty_freq: from.parameters.frequency_penalty as _,
penalty_present: 0.0, // disabled
max_new_tokens: from.stopping_parameters.max_new_tokens as _,
tx,
time: Instant::now(),
})
}
}
struct Llamacpp {
model: *mut llamacpp::llama_model,
ctx: *mut llamacpp::llama_context,
vocab: *const llamacpp::llama_vocab,
logprobs: Vec<llamacpp::llama_token_data>,
batch: llamacpp::llama_batch,
}
extern "C" fn llamacpp_log_callback(
level: llamacpp::ggml_log_level,
msg: *const std::os::raw::c_char,
_user_data: *mut std::os::raw::c_void,
) {
let cmsg = unsafe { std::ffi::CStr::from_ptr(msg) };
let rmsg = cmsg.to_string_lossy().trim_end_matches('\n').to_string();
match level {
llamacpp::GGML_LOG_LEVEL_DEBUG => debug!(target: "llamacpp", "{}", rmsg),
llamacpp::GGML_LOG_LEVEL_INFO => info!(target: "llamacpp", "{}", rmsg),
llamacpp::GGML_LOG_LEVEL_WARN => warn!(target: "llamacpp", "{}", rmsg),
llamacpp::GGML_LOG_LEVEL_ERROR => error!(target: "llamacpp", "{}", rmsg),
_ => trace!(target: "llamacpp", "{}", rmsg),
}
}
impl Llamacpp {
fn new(conf: LlamacppConfig) -> Result<Self, BackendError> {
let gguf = CString::new(conf.model_gguf)?;
let model = unsafe {
let mut params = llamacpp::model_default_params();
params.n_gpu_layers = conf.n_gpu_layers as _;
params.split_mode = match conf.split_mode {
LlamacppSplitMode::GPU(_) => llamacpp::LLAMA_SPLIT_MODE_NONE,
LlamacppSplitMode::Layer => llamacpp::LLAMA_SPLIT_MODE_LAYER,
LlamacppSplitMode::Row => llamacpp::LLAMA_SPLIT_MODE_ROW,
};
params.main_gpu = match conf.split_mode {
LlamacppSplitMode::GPU(n) => n as _,
_ => 0,
};
params.use_mmap = conf.use_mmap;
params.use_mlock = conf.use_mlock;
llamacpp::model_load_from_file(gguf.as_ptr(), params)
};
if model.is_null() {
return Err(BackendError::Llamacpp("Failed to load model".to_string()));
}
let ctx = unsafe {
let mut params = llamacpp::context_default_params();
params.n_ctx = conf.max_batch_total_tokens as _;
params.n_batch = conf.max_batch_total_tokens as _;
params.n_ubatch = conf.max_physical_batch_total_tokens as _;
params.n_seq_max = conf.max_batch_size as _;
params.n_threads = conf.n_threads as _;
params.n_threads_batch = conf.n_threads_batch as _;
params.defrag_thold = conf.defrag_threshold;
params.offload_kqv = conf.offload_kqv;
params.flash_attn = conf.flash_attention;
params.type_k = conf.type_k.to_ggml_type();
params.type_v = conf.type_v.to_ggml_type();
params.no_perf = true;
llamacpp::init_from_model(model, params)
};
if ctx.is_null() {
return Err(BackendError::Llamacpp("Failed to init context".to_string()));
}
let vocab = unsafe { llamacpp::model_get_vocab(model) };
if vocab.is_null() {
return Err(BackendError::Llamacpp("Failed to get vocab".to_string()));
}
let n_tokens = unsafe { llamacpp::vocab_n_tokens(vocab) };
let mut logprobs = Vec::with_capacity(n_tokens as usize);
for token in 0..n_tokens {
logprobs.push(llamacpp::llama_token_data {
id: token,
logit: 0.0,
p: 0.0,
});
}
let batch = unsafe { llamacpp::batch_init(conf.max_batch_total_tokens as _, 0, 1) };
Ok(Llamacpp {
model,
ctx,
vocab,
logprobs,
batch,
})
}
fn decode(&mut self) -> i32 {
unsafe { llamacpp::decode(self.ctx, self.batch) }
}
fn clear_kv_cache(&mut self, seq_id: llamacpp::llama_seq_id) {
unsafe {
llamacpp::kv_cache_seq_rm(self.ctx, seq_id, -1, -1);
}
}
fn batch_push(
&mut self,
token: llamacpp::llama_token,
pos: llamacpp::llama_pos,
seq_id: llamacpp::llama_seq_id,
logits: bool,
) -> usize {
let n = self.batch.n_tokens as usize;
unsafe {
*self.batch.token.add(n) = token;
*self.batch.pos.add(n) = pos;
*self.batch.n_seq_id.add(n) = 1;
*(*self.batch.seq_id.add(n)).add(0) = seq_id;
*self.batch.logits.add(n) = logits as i8;
}
self.batch.n_tokens += 1;
n
}
}
impl Drop for Llamacpp {
fn drop(&mut self) {
if !self.ctx.is_null() {
unsafe { llamacpp::free(self.ctx) };
}
if !self.model.is_null() {
unsafe { llamacpp::model_free(self.model) };
}
unsafe { llamacpp::batch_free(self.batch) };
}
}
struct LlamacppSampler {
chain: *mut llamacpp::llama_sampler,
}
impl LlamacppSampler {
fn new(req: &LlamacppRequest) -> Option<Self> {
let chain = unsafe {
let params = llamacpp::sampler_chain_default_params();
llamacpp::sampler_chain_init(params)
};
if chain.is_null() {
error!("Failed to init sampler");
return None;
}
let (top_k, top_p, typical_p, temp, penalties, dist) = unsafe {
(
llamacpp::sampler_init_top_k(req.top_k),
llamacpp::sampler_init_top_p(req.top_p, req.min_keep),
llamacpp::sampler_init_typical(req.typical_p, req.min_keep),
llamacpp::sampler_init_temp(req.temp),
llamacpp::sampler_init_penalties(
req.penalty_last_n,
req.penalty_repeat,
req.penalty_freq,
req.penalty_present,
),
llamacpp::sampler_init_dist(req.seed),
)
};
let all = &[
("top_k", top_k),
("top_p", top_p),
("typical_p", typical_p),
("temp", temp),
("penalties", penalties),
("dist", dist),
];
let mut failed = false;
for (k, v) in all {
if v.is_null() {
error!("Failed to init {k} sampler");
failed = true;
} else {
unsafe { llamacpp::sampler_chain_add(chain, *v) };
}
}
if failed {
unsafe { llamacpp::sampler_free(chain) };
None
} else {
Some(LlamacppSampler { chain })
}
}
fn sample(&self, llamacpp: &mut Llamacpp, idx: usize) -> (llamacpp::llama_token, f32) {
let logits = unsafe { llamacpp::get_logits_ith(llamacpp.ctx, idx as _) };
for (token, logprob) in llamacpp.logprobs.iter_mut().enumerate() {
*logprob = llamacpp::llama_token_data {
id: token as _,
logit: unsafe { *logits.add(token) },
p: 0.0,
};
}
let mut view = llamacpp::llama_token_data_array {
data: llamacpp.logprobs.as_mut_ptr(),
size: llamacpp.logprobs.len(),
selected: -1,
sorted: false,
};
unsafe {
llamacpp::sampler_apply(self.chain, &mut view);
let logprob = *view.data.offset(view.selected as _);
llamacpp::sampler_accept(self.chain, logprob.id);
(logprob.id, logprob.p.ln())
}
}
}
impl Drop for LlamacppSampler {
fn drop(&mut self) {
if !self.chain.is_null() {
unsafe { llamacpp::sampler_free(self.chain) };
}
}
}
struct LlamacppSeq {
id: usize,
batch_pos: usize,
token: llamacpp::llama_token,
pos: llamacpp::llama_pos,
sampler: LlamacppSampler,
text: String,
n_new_tokens: usize,
running: bool,
}
static INIT: Once = Once::new();
impl LlamacppBackend {
pub fn new(
conf: LlamacppConfig,
tokenizer: Tokenizer,
) -> (
Self,
oneshot::Receiver<Result<(), BackendError>>,
watch::Sender<bool>,
) {
// Setup llama & export logs, once and for all
INIT.call_once(|| unsafe {
llamacpp::log_set(Some(llamacpp_log_callback), std::ptr::null_mut());
llamacpp::backend_init();
llamacpp::numa_init(match conf.numa {
LlamacppNuma::Disabled => llamacpp::GGML_NUMA_STRATEGY_DISABLED,
LlamacppNuma::Distribute => llamacpp::GGML_NUMA_STRATEGY_DISTRIBUTE,
LlamacppNuma::Isolate => llamacpp::GGML_NUMA_STRATEGY_ISOLATE,
LlamacppNuma::Numactl => llamacpp::GGML_NUMA_STRATEGY_NUMACTL,
LlamacppNuma::Mirror => llamacpp::GGML_NUMA_STRATEGY_MIRROR,
});
});
let (status_tx, status_rx) = watch::channel(false);
let (shutdown_tx, shutdown_rx) = watch::channel(false);
let (ok_tx, ok_rx) = oneshot::channel();
let (tx, mut rx) = unbounded_channel::<LlamacppRequest>();
let (sync_tx, sync_rx) = mpsc::channel();
spawn(async move {
let mut n_tokens = 0;
let mut requests = Vec::with_capacity(conf.max_batch_size);
let flush = |requests: &mut Vec<_>, n_tokens: &mut usize| {
if !requests.is_empty() {
let _ =
sync_tx.send(replace(requests, Vec::with_capacity(conf.max_batch_size)));
*n_tokens = 0;
}
};
loop {
match timeout(conf.batch_timeout, rx.recv()).await {
Ok(Some(request)) => {
let n_tokens_to_add = request.input_ids.len();
if n_tokens + n_tokens_to_add > conf.max_batch_total_tokens {
flush(&mut requests, &mut n_tokens);
}
n_tokens += n_tokens_to_add;
requests.push(request);
if requests.len() == conf.max_batch_size {
flush(&mut requests, &mut n_tokens);
}
}
Ok(None) => break, // closed
Err(_) => flush(&mut requests, &mut n_tokens), // timeout
}
}
});
spawn_blocking(move || {
let mut llamacpp = match Llamacpp::new(conf) {
Ok(v) => {
let _ = ok_tx.send(Ok(()));
v
}
Err(e) => {
let _ = ok_tx.send(Err(e));
return;
}
};
let vocab = tokenizer.get_added_vocabulary();
// health() returns true
let _ = status_tx.send(true);
while let Ok(requests) = sync_rx.recv() {
if *shutdown_rx.borrow() {
break;
}
let start_time = Instant::now();
let mut seqs: Vec<LlamacppSeq> = Vec::with_capacity(requests.len());
llamacpp.batch.n_tokens = 0;
for (seq_id, request) in requests.iter().enumerate() {
debug!("Request: {:?}", request);
// TODO remove this
let sampler = match LlamacppSampler::new(request) {
Some(sampler) => sampler,
_ => {
let _ = request.tx.send(Err(InferError::IncompleteGeneration));
continue;
}
};
let last_pos = request.input_ids.len() - 1;
for (pos, &token_id) in request.input_ids.iter().enumerate() {
llamacpp.batch_push(
token_id as llamacpp::llama_token,
pos as llamacpp::llama_pos,
seq_id as llamacpp::llama_seq_id,
pos == last_pos, // check samplers
);
}
seqs.push(LlamacppSeq {
id: seq_id,
batch_pos: llamacpp.batch.n_tokens as usize - 1,
token: llamacpp::LLAMA_TOKEN_NULL,
pos: last_pos as llamacpp::llama_pos + 1,
sampler,
text: String::with_capacity(1024),
n_new_tokens: 0,
running: true,
});
}
while llamacpp.batch.n_tokens > 0 {
if llamacpp.decode() != 0 {
warn!("llama_decode failed, clearing kv cache");
llamacpp.clear_kv_cache(-1);
for seq in seqs.iter_mut() {
let _ = requests[seq.id]
.tx
.send(Err(InferError::IncompleteGeneration));
seq.running = false;
}
break;
}
for seq in seqs.iter_mut() {
if !seq.running {
continue;
}
let (next, logprob) = seq.sampler.sample(&mut llamacpp, seq.batch_pos);
seq.n_new_tokens += 1;
seq.token = next;
let piece = match tokenizer.decode(&[next as u32], false) {
Ok(piece) => piece,
Err(e) => {
error!("Failed to decode token: {e}");
let _ = requests[seq.id]
.tx
.send(Err(InferError::IncompleteGeneration));
seq.running = false;
continue;
}
};
let special = vocab.is_special_token(&piece);
if !special {
seq.text.push_str(&piece);
}
let token = Token {
id: next as _,
text: piece,
logprob,
special,
};
let finish: Option<FinishReason> = {
if unsafe { llamacpp::vocab_is_eog(llamacpp.vocab, next) } {
Some(FinishReason::EndOfSequenceToken)
} else if seq.n_new_tokens == requests[seq.id].max_new_tokens {
Some(FinishReason::Length)
} else {
None
}
};
if let Some(reason) = finish {
let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::End {
token,
top_tokens: vec![],
generated_text: GeneratedText {
text: seq.text.clone(),
generated_tokens: seq.n_new_tokens as _,
finish_reason: reason,
seed: Some(requests[seq.id].seed as _),
},
start: start_time,
queued: requests[seq.id].time,
}));
seq.running = false;
continue;
}
let _ = requests[seq.id]
.tx
.send(Ok(InferStreamResponse::Intermediate {
token,
top_tokens: vec![],
}));
}
// generate a new batch
llamacpp.batch.n_tokens = 0;
for seq in seqs.iter_mut() {
if seq.running {
seq.batch_pos =
llamacpp.batch_push(seq.token, seq.pos, seq.id as _, true);
seq.pos += 1;
} else {
llamacpp.clear_kv_cache(seq.id as _);
}
}
}
}
});
(
Self {
tx,
status: status_rx,
},
ok_rx,
shutdown_tx,
)
}
}
#[async_trait]
impl Backend for LlamacppBackend {
#[instrument(skip_all)]
fn schedule(
&self,
request: ValidGenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
debug!(?request);
let (tx, rx) = unbounded_channel::<Result<InferStreamResponse, InferError>>();
match LlamacppRequest::new(&request, tx) {
Some(v) => match self.tx.send(v) {
Err(e) => Err(InferError::GenerationError(e.to_string())),
_ => Ok(UnboundedReceiverStream::new(rx)),
},
_ => Err(InferError::GenerationError("Bad request".to_string())),
}
}
async fn health(&self, _: bool) -> bool {
*self.status.borrow()
}
fn name(&self) -> &'static str {
"llamacpp"
}
}
#[derive(Debug, Error)]
pub enum BackendError {
#[error("CString error: {0}")]
CStringError(#[from] std::ffi::NulError),
#[error("Llamacpp error: {0}")]
Llamacpp(String),
}

backends/llamacpp/src/main.rs (new file)

@@ -0,0 +1,284 @@
mod backend;
use backend::{
BackendError, LlamacppBackend, LlamacppConfig, LlamacppGGMLType, LlamacppNuma,
LlamacppSplitMode,
};
use clap::Parser;
use text_generation_router::{logging, server, usage_stats};
use thiserror::Error;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tokio::sync::oneshot::error::RecvError;
use tracing::{error, warn};
/// Backend Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
/// Name of the model to load.
#[clap(long, env)]
model_id: String,
/// Revision of the model.
#[clap(default_value = "main", long, env)]
revision: String,
/// Path to the GGUF model file for inference.
#[clap(long, env)]
model_gguf: String, // TODO Option() with hf->gguf & quantize
/// Number of threads to use for generation.
#[clap(long, env)]
n_threads: Option<usize>,
/// Number of threads to use for batch processing.
#[clap(long, env)]
n_threads_batch: Option<usize>,
/// Number of layers to store in VRAM.
#[clap(default_value = "0", long, env)]
n_gpu_layers: usize,
/// Split the model across multiple GPUs.
#[clap(default_value = "layer", long, env)]
split_mode: LlamacppSplitMode,
/// Defragment the KV cache if holes/size > threshold.
#[clap(default_value = "-1.0", long, env)]
defrag_threshold: f32,
/// Enable NUMA optimizations.
#[clap(default_value = "disabled", value_enum, long, env)]
numa: LlamacppNuma,
/// Use memory mapping for the model.
#[clap(long, env)]
use_mmap: bool,
/// Use memory locking to prevent swapping.
#[clap(long, env)]
use_mlock: bool,
/// Enable offloading of KQV operations to the GPU.
#[clap(long, env)]
offload_kqv: bool,
/// Enable flash attention for faster inference. (EXPERIMENTAL)
#[clap(long, env)]
flash_attention: bool,
/// Data type used for K cache.
#[clap(default_value = "f16", value_enum, long, env)]
type_k: LlamacppGGMLType,
/// Data type used for V cache.
#[clap(default_value = "f16", value_enum, long, env)]
type_v: LlamacppGGMLType,
/// Number of tokenizer workers used for payload validation and truncation.
#[clap(default_value = "2", long, env)]
validation_workers: usize,
/// Maximum number of concurrent requests.
#[clap(long, env)]
max_concurrent_requests: Option<usize>,
/// Maximum number of input tokens per request.
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
/// Maximum number of total tokens (input + output) per request.
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
/// Maximum number of tokens in a batch.
#[clap(long, env)]
max_batch_total_tokens: Option<usize>,
/// Maximum number of tokens in a physical batch.
#[clap(long, env)]
max_physical_batch_total_tokens: Option<usize>,
/// Maximum number of requests per batch.
#[clap(long, env)]
max_batch_size: Option<usize>,
/// IP address to listen on.
#[clap(default_value = "0.0.0.0", long)]
hostname: String,
/// Port to listen on.
#[clap(default_value = "3000", long, short, env)]
port: u16,
/// Enable JSON output format.
#[clap(long, env)]
json_output: bool,
/// OTLP endpoint for telemetry data.
#[clap(long, env)]
otlp_endpoint: Option<String>,
/// Service name for OTLP telemetry.
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
/// Allowed origins for CORS.
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
/// Path to the tokenizer configuration file.
#[clap(long, env)]
tokenizer_config_path: Option<String>,
/// Disable grammar support.
#[clap(long, env)]
disable_grammar_support: bool,
/// Maximum number of inputs per request.
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
/// Level of usage statistics collection.
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
/// Maximum payload size in bytes.
#[clap(default_value = "2000000", long, env)]
payload_limit: usize,
}
#[tokio::main]
async fn main() -> Result<(), RouterError> {
let args = Args::parse();
logging::init_logging(args.otlp_endpoint, args.otlp_service_name, args.json_output);
let n_threads = match args.n_threads {
Some(0) | None => num_cpus::get(),
Some(threads) => threads,
};
let n_threads_batch = match args.n_threads_batch {
Some(0) | None => n_threads,
Some(threads) => threads,
};
let max_batch_size = match args.max_batch_size {
Some(0) | None => n_threads_batch,
Some(threads) => threads,
};
let max_batch_total_tokens = match args.max_batch_total_tokens {
None => max_batch_size * args.max_total_tokens,
Some(size) => size,
};
let max_physical_batch_total_tokens = match args.max_physical_batch_total_tokens {
None => max_batch_total_tokens,
Some(size) => size,
};
let max_concurrent_requests = match args.max_concurrent_requests {
None => max_batch_size * 2,
Some(size) => size,
};
if args.max_input_tokens >= args.max_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if args.max_total_tokens > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
));
}
if max_batch_size * args.max_total_tokens > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`".to_string(),
));
}
// TODO: check if we use the same cache of Server
// check if llamacpp is faster
let tokenizer = {
let token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
let params = FromPretrainedParameters {
revision: args.revision.clone(),
token,
..Default::default()
};
Tokenizer::from_pretrained(args.model_id.clone(), Some(params))?
};
let (backend, ok, shutdown) = LlamacppBackend::new(
LlamacppConfig {
model_gguf: args.model_gguf,
n_threads,
n_threads_batch,
n_gpu_layers: args.n_gpu_layers,
split_mode: args.split_mode,
defrag_threshold: args.defrag_threshold,
numa: args.numa,
use_mmap: args.use_mmap,
use_mlock: args.use_mlock,
flash_attention: args.flash_attention,
type_k: args.type_k,
type_v: args.type_v,
offload_kqv: args.offload_kqv,
max_batch_total_tokens,
max_physical_batch_total_tokens,
max_batch_size,
batch_timeout: tokio::time::Duration::from_millis(5),
},
tokenizer,
);
ok.await??;
if cfg!(debug_assertions) {
warn!("Graceful shutdown disabled!");
let _ = tokio::task::spawn(async move {
let _ = tokio::signal::ctrl_c().await;
let _ = shutdown.send(true);
});
}
server::run(
backend,
max_concurrent_requests,
0, // max_best_of
0, // max_stop_sequences
0, // max_top_n_tokens
args.max_input_tokens,
args.max_total_tokens,
args.validation_workers,
None, // api_key
args.model_id, // tokenizer_name
args.tokenizer_config_path,
Some(args.revision),
false, // trust_remote_code
args.hostname,
args.port,
args.cors_allow_origin,
false, // ngrok,
None, // ngrok_authtoken,
None, // ngrok_edge,
args.disable_grammar_support,
args.max_client_batch_size,
args.usage_stats,
args.payload_limit,
)
.await?;
Ok(())
}
#[derive(Debug, Error)]
enum RouterError {
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
#[error("Tokenizer error: {0}")]
Tokenizer(#[from] tokenizers::Error),
#[error("Backend error: {0}")]
Backend(#[from] BackendError),
#[error("WebServer error: {0}")]
WebServer(#[from] server::WebServerError),
#[error("Recv error: {0}")]
RecvError(#[from] RecvError),
}

docs/source/_toc.yml

@@ -52,6 +52,8 @@
- sections:
- local: backends/trtllm
title: TensorRT-LLM
- local: backends/llamacpp
title: Llamacpp
title: Backends
- sections:
- local: reference/launcher

docs/source/backends/llamacpp.md (new file)

@@ -0,0 +1,120 @@
# Llamacpp Backend
The llamacpp backend facilitates the deployment of large language models
(LLMs) by integrating [llama.cpp][llama.cpp], an advanced inference engine
optimized for both CPU and GPU computation. This backend is a component
of Hugging Face's **Text Generation Inference (TGI)** suite,
specifically designed to streamline the deployment of LLMs in production
environments.
## Key Capabilities
- Full compatibility with the GGUF format and all quantization formats
  (GGUF-related constraints may be mitigated dynamically by on-the-fly
  generation in future updates)
- Optimized inference on CPU and GPU architectures
- Containerized deployment, eliminating dependency complexity
- Seamless interoperability with the Hugging Face ecosystem
## Model Compatibility
This backend leverages models formatted in **GGUF**, providing an
optimized balance between computational efficiency and model accuracy.
You will find the best models on [Hugging Face][GGUF].
## Build Docker image
For optimal performance, the Docker image is compiled with native CPU
instructions, thus it's highly recommended to execute the container on
the host used during the build process. Efforts are ongoing to enhance
portability while maintaining high computational efficiency.
```bash
docker build \
-t tgi-llamacpp \
https://github.com/huggingface/text-generation-inference.git \
-f Dockerfile_llamacpp
```
### Build parameters
| Parameter | Description |
| ------------------------------------ | --------------------------------- |
| `--build-arg llamacpp_version=bXXXX` | Specific version of llama.cpp |
| `--build-arg llamacpp_cuda=ON` | Enables CUDA acceleration |
| `--build-arg cuda_arch=ARCH` | Defines target CUDA architecture |
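
For example, a CUDA-enabled image could be built as shown below; the `cuda_arch` value (`86-real`) is a placeholder and should match your target GPU, and `llamacpp_version` can be pinned in the same way if needed:

```bash
docker build \
    -t tgi-llamacpp-cuda \
    --build-arg llamacpp_cuda=ON \
    --build-arg cuda_arch="86-real" \
    https://github.com/huggingface/text-generation-inference.git \
    -f Dockerfile_llamacpp
```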
## Model preparation
Retrieve a GGUF model and store it in a specific directory, for example:
```bash
mkdir -p ~/models
cd ~/models
curl -LOJ "https://huggingface.co/Qwen/Qwen2.5-3B-Instruct-GGUF/resolve/main/qwen2.5-3b-instruct-q4_0.gguf?download=true"
```
## Run Docker image
### CPU-based inference
```bash
docker run \
-p 3000:3000 \
-e "HF_TOKEN=$HF_TOKEN" \
-v "$HOME/models:/models" \
tgi-llamacpp \
--model-id "Qwen/Qwen2.5-3B-Instruct" \
--model-gguf "/models/qwen2.5-3b-instruct-q4_0.gguf"
```
### GPU-Accelerated inference
```bash
docker run \
--gpus all \
-p 3000:3000 \
-e "HF_TOKEN=$HF_TOKEN" \
-v "$HOME/models:/models" \
tgi-llamacpp \
--n-gpu-layers 99 \
--model-id "Qwen/Qwen2.5-3B-Instruct" \
--model-gguf "/models/qwen2.5-3b-instruct-q4_0.gguf"
```
## Advanced parameters
A full listing of configurable parameters is available via `--help`:
```bash
docker run tgi-llamacpp --help
```
The table below summarizes key options:
| Parameter | Description |
|-------------------------------------|------------------------------------------------------------------------|
| `--n-threads` | Number of threads to use for generation |
| `--n-threads-batch` | Number of threads to use for batch processing |
| `--n-gpu-layers` | Number of layers to store in VRAM |
| `--split-mode` | Split the model across multiple GPUs |
| `--defrag-threshold` | Defragment the KV cache if holes/size > threshold |
| `--numa` | Enable NUMA optimizations |
| `--use-mmap` | Use memory mapping for the model |
| `--use-mlock` | Use memory locking to prevent swapping |
| `--offload-kqv` | Enable offloading of KQV operations to the GPU |
| `--flash-attention` | Enable flash attention for faster inference |
| `--type-k` | Data type used for K cache |
| `--type-v` | Data type used for V cache |
| `--validation-workers` | Number of tokenizer workers used for payload validation and truncation |
| `--max-concurrent-requests` | Maximum number of concurrent requests |
| `--max-input-tokens` | Maximum number of input tokens per request |
| `--max-total-tokens` | Maximum number of total tokens (input + output) per request |
| `--max-batch-total-tokens` | Maximum number of tokens in a batch |
| `--max-physical-batch-total-tokens` | Maximum number of tokens in a physical batch |
| `--max-batch-size` | Maximum number of requests per batch |
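
For illustration, the hypothetical run below combines a few of these options on a GPU host; the numeric values (GPU layers, token budget, batch size) are placeholders to be tuned for your hardware and model:

```bash
docker run \
    --gpus all \
    -p 3000:3000 \
    -e "HF_TOKEN=$HF_TOKEN" \
    -v "$HOME/models:/models" \
    tgi-llamacpp \
    --n-gpu-layers 99 \
    --max-total-tokens 4096 \
    --max-batch-size 4 \
    --model-id "Qwen/Qwen2.5-3B-Instruct" \
    --model-gguf "/models/qwen2.5-3b-instruct-q4_0.gguf"
```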
---
[llama.cpp]: https://github.com/ggerganov/llama.cpp
[GGUF]: https://huggingface.co/models?library=gguf&sort=trending

docs/source/multi_backend_support.md

@@ -11,3 +11,5 @@ TGI remains consistent across backends, allowing you to switch between them seamlessly.
* **[TGI TRTLLM backend](./backends/trtllm)**: This backend leverages NVIDIA's TensorRT library to accelerate LLM inference.
It utilizes specialized optimizations and custom kernels for enhanced performance.
However, it requires a model-specific compilation step for each GPU architecture.
* **[TGI Llamacpp backend](./backends/llamacpp)**: This backend facilitates the deployment of large language models
(LLMs) by integrating [llama.cpp][llama.cpp], an advanced inference engine optimized for both CPU and GPU computation.