diff --git a/Cargo.lock b/Cargo.lock
index bb63422e..d298c379 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4028,6 +4028,7 @@ dependencies = [
  "cxx",
  "cxx-build",
  "log",
+ "parking_lot",
  "pkg-config",
  "text-generation-router",
  "thiserror",
diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml
index 7079d3d1..43a114ba 100644
--- a/backends/trtllm/Cargo.toml
+++ b/backends/trtllm/Cargo.toml
@@ -8,17 +8,18 @@ homepage.workspace = true
 [dependencies]
 async-trait = "0.1"
 async-stream = "0.3"
+clap = { version = "4.5", features = ["derive"] }
 cxx = "1.0"
+log = { version = "0.4", features = [] }
 text-generation-router = { path = "../../router" }
 tokenizers = { version = "0.19", features = ["hf-hub"] }
 tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
 tokio-stream = "0.1.15"
-clap = { version = "4.5", features = ["derive"] }
 thiserror = "1.0.62"
 tracing = "0.1"
 tracing-opentelemetry = "0.24"
 tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
-log = { version = "0.4", features = [] }
+parking_lot = "0.12"
 
 [build-dependencies]
 cmake = "0.1"
diff --git a/backends/trtllm/Dockerfile b/backends/trtllm/Dockerfile
index 60ad03f7..5fd2f89f 100644
--- a/backends/trtllm/Dockerfile
+++ b/backends/trtllm/Dockerfile
@@ -3,7 +3,7 @@ ARG OMPI_VERSION="4.1.6"
 
 # Build dependencies resolver stage
 FROM lukemathwalker/cargo-chef:latest AS chef
-WORKDIR /usr/src/text-generation-inference
+WORKDIR /usr/src/text-generation-inference/backends/trtllm
 
 FROM chef AS planner
 COPY . .
@@ -42,7 +42,7 @@ RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILE
     mkdir /usr/src/mpi && \
     tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
     cd /usr/src/mpi && \
-    ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --without-slurm && \
+    ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda && \
     make -j all && \
     make install && \
     rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
@@ -66,7 +66,7 @@ ENV PATH="/root/.cargo/bin:$PATH"
 RUN cargo install cargo-chef
 
 # Cache dependencies
-COPY --from=planner /usr/src/text-generation-inference/recipe.json .
+COPY --from=planner /usr/src/text-generation-inference/backends/trtllm/recipe.json .
 RUN cargo chef cook --release --recipe-path recipe.json
 
 # Build actual TGI
@@ -79,7 +79,8 @@ COPY . .
 COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
 COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
 RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
-    CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release --bin text-generation-backends-trtllm
+    cd backends/trtllm && \
+    CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release
 
 FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime
 WORKDIR /usr/local/tgi/bin
diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs
index b26d06a6..b23aa6c0 100644
--- a/backends/trtllm/src/backend.rs
+++ b/backends/trtllm/src/backend.rs
@@ -12,12 +12,13 @@ use cxx::UniquePtr;
 use log::{error, warn};
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
-use tokio::sync::RwLock;
 use tokio::time::{sleep, Instant};
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::{Stream, StreamExt};
 use tracing::{instrument, span, Level};
 
+// use tokio::sync::RwLock;
+use parking_lot::RwLock;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidationError::UnsupportedModality;
 use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError};
diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs
index 6d6ee146..e0ba46c7 100644
--- a/backends/trtllm/src/main.rs
+++ b/backends/trtllm/src/main.rs
@@ -1,12 +1,10 @@
+use clap::Parser;
 use std::collections::HashMap;
 use std::path::PathBuf;
-
-use clap::Parser;
-use tokenizers::{FromPretrainedParameters, Tokenizer};
-
 use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
 use text_generation_backends_trtllm::TensorRtLlmBackend;
 use text_generation_router::server;
+use tokenizers::{FromPretrainedParameters, Tokenizer};
 
 /// App Configuration
 #[derive(Parser, Debug)]
@@ -160,6 +158,8 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         messages_api_enabled,
         true,
         max_client_batch_size,
+        false,
+        false,
     )
     .await?;
     Ok(())
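Not part of the patch, but for context on the backend.rs change from tokio::sync::RwLock to parking_lot::RwLock: below is a minimal standalone sketch of how the two lock APIs differ. The BackendState struct and its in_flight field are invented for illustration only, and the example assumes only parking_lot = "0.12" (the dependency added in Cargo.toml above) is available.

```rust
// Illustrative only: BackendState and in_flight are hypothetical names,
// not types from the TensorRT-LLM backend.
use std::sync::Arc;

use parking_lot::RwLock;

#[derive(Default)]
struct BackendState {
    in_flight: usize,
}

fn main() {
    let state = Arc::new(RwLock::new(BackendState::default()));

    {
        // No `.await`: parking_lot returns the guard directly and releases it
        // when the scope ends. The tokio equivalent, `state.write().await`,
        // only compiles inside async code.
        let mut guard = state.write();
        guard.in_flight += 1;
    }

    println!("requests in flight: {}", state.read().in_flight);
}
```

The trade-off is that a parking_lot lock blocks the calling thread while contended instead of yielding to the tokio runtime, which is generally acceptable for short critical sections and lets the lock be taken from non-async code paths.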