# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.80.1 AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

# Compute the cargo-chef dependency recipe from the workspace manifests.
FROM chef AS planner
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    python3.11-dev
# -f makes curl fail on an HTTP error instead of saving the error page,
# which would otherwise only surface later as an unzip failure.
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -fOL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
    rm -f $PROTOC_ZIP

# Build dependencies only first, so this layer is cached across source changes.
COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --profile release-opt --recipe-path recipe.json

ARG GIT_SHA
ARG DOCKER_LABEL

COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt --frozen

# Text Generation Inference base image for ROCm
FROM rocm/dev-ubuntu-22.04:6.2 AS base

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    build-essential \
    ca-certificates \
    ccache \
    curl \
    git \
    make \
    libmsgpack-dev \
    libssl-dev \
    llvm-dev \
    g++ \
    # Needed to build VLLM & flash.
    rocthrust-dev \
    hipsparse-dev \
    hipblas-dev \
    hipcub-dev \
    rocblas-dev \
    hiprand-dev \
    hipfft-dev \
    rocrand-dev \
    miopen-hip-dev \
    hipsolver-dev \
    rccl-dev \
    cmake \
    python3.11-venv && \
    rm -rf /var/lib/apt/lists/*

# Keep in sync with `server/pyproject.toml`
ARG MAMBA_VERSION=23.1.0-1
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH=/opt/conda/bin:$PATH

ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"

# NOTE(review): an earlier comment justified conda with Ubuntu 20.04's Python 3.8 and
# libssl 1.1, but this stage is based on Ubuntu 22.04. conda is used here to provide
# Python ${PYTHON_VERSION} (and MKL below). Confirm whether that rationale still holds.

# Install mamba
# translating Docker's TARGETPLATFORM into mamba arches
RUN case "${TARGETPLATFORM}" in \
         "linux/arm64")  MAMBA_ARCH=aarch64  ;; \
         *)              MAMBA_ARCH=x86_64   ;; \
    esac && \
    curl -fsSL -v -o ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
    bash ~/mambaforge.sh -b -p /opt/conda && \
    mamba init && \
    rm ~/mambaforge.sh

# Install Python ${PYTHON_VERSION} into the conda prefix (PyTorch itself is built
# from source in the build_pytorch stage).
# On arm64 we exit with an error code
RUN case "${TARGETPLATFORM}" in \
         "linux/arm64")  exit 1 ;; \
         *)              /opt/conda/bin/conda update -y conda &&  \
                         /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
    esac && \
    /opt/conda/bin/conda clean -ya

# Install flash-attention, torch dependencies
RUN python3 -m pip install --upgrade pip && \
    pip install numpy einops ninja joblib msgpack cmake --no-cache-dir

# -y: never depend on the interactive confirmation prompt during a build.
RUN conda install -y mkl=2021

# ${LD_LIBRARY_PATH:+...:} only prepends the old value (plus a colon) when it is
# non-empty; a bare "$LD_LIBRARY_PATH:" would leave an empty leading entry, which
# the dynamic linker treats as the current working directory.
# NOTE: the torch/lib path assumes Python 3.11 (see PYTHON_VERSION above).
ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}/opt/rocm/lib/:/opt/conda/lib/python3.11/site-packages/torch/lib:/opt/conda/lib/

ARG COMMON_WORKDIR=/
WORKDIR ${COMMON_WORKDIR}

# Install HIPBLASLt
FROM base AS build_hipblaslt
ARG HIPBLASLT_BRANCH="e6da924"
RUN git clone https://github.com/ROCm/hipBLASLt.git \
    && cd hipBLASLt \
    && git checkout ${HIPBLASLT_BRANCH} \
    && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture "${PYTORCH_ROCM_ARCH}" --legacy_hipblas_direct \
    && cd build/release \
    && make package

FROM scratch AS export_hipblaslt
ARG COMMON_WORKDIR
COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /

# RCCL build stages
FROM base AS build_rccl
ARG RCCL_BRANCH="rocm-6.2.0"
RUN git clone https://github.com/ROCm/rccl \
    && cd rccl \
    && git checkout ${RCCL_BRANCH} \
    && ./install.sh -p --amdgpu_targets "${PYTORCH_ROCM_ARCH}"

FROM scratch AS export_rccl
ARG COMMON_WORKDIR
COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /

# Triton build stages
FROM base AS build_triton
ARG TRITON_BRANCH="e192dba"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
RUN python3 -m pip install ninja cmake wheel pybind11 && git clone ${TRITON_REPO} \
    && cd triton \
    && git checkout ${TRITON_BRANCH} \
    && cd python \
    && python3 setup.py bdist_wheel --dist-dir=dist

FROM scratch AS export_triton
ARG COMMON_WORKDIR
COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /

# AMD-SMI build stages
FROM base AS build_amdsmi
RUN cd /opt/rocm/share/amd_smi \
    && pip wheel . --wheel-dir=dist

FROM scratch AS export_amdsmi
COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /

FROM base AS build_pytorch

# Install the locally built hipBLASLt packages and strip them from the dependency
# lists recorded in dpkg's status file. Keep in sync with the same step in install_deps.
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
    if ls /install/*.deb; then \
        dpkg -i /install/*.deb \
        && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
        && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
    fi

ARG BUILD_ENVIRONMENT=pytorch-linux-jammy-rocm6.2-py3.11
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"

# A commit to fix the output scaling factor issue in _scaled_mm
# Not yet in 2.5.0-rc1
ARG PYTORCH_BRANCH="cedc116"
# NOTE(review): not consumed by any instruction in this file; torchvision is not built.
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"

RUN git clone ${PYTORCH_REPO} pytorch \
    && cd pytorch && git checkout ${PYTORCH_BRANCH} && git submodule update --init --recursive \
    && pip install -r requirements.txt --no-cache-dir \
    && python tools/amd_build/build_amd.py \
    && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist

FROM scratch AS export_pytorch
ARG COMMON_WORKDIR
COPY --from=build_pytorch ${COMMON_WORKDIR}/pytorch/dist/*.whl /

FROM base AS install_deps

ARG COMMON_WORKDIR

# Install hipblaslt (keep in sync with the same step in build_pytorch)
RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
    if ls /install/*.deb; then \
        dpkg -i /install/*.deb \
        && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
        && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
    fi

RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
    if ls /install/*.deb; then \
        dpkg -i /install/*.deb \
        # RCCL needs to be installed twice
        && dpkg -i /install/*.deb \
        && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
        && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \
    fi

RUN --mount=type=bind,from=export_triton,src=/,target=/install \
    if ls /install/*.whl; then \
        # Preemptively uninstall to prevent pip same-version no-installs
        pip uninstall -y triton \
        && pip install /install/*.whl; \
    fi

RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
    # Preemptively uninstall to prevent pip same-version no-installs
    pip uninstall -y amdsmi \
    && pip install /install/*.whl;

RUN --mount=type=bind,from=export_pytorch,src=/,target=/install \
    if ls /install/*.whl; then \
        # Preemptively uninstall to prevent pip same-version no-installs
        pip uninstall -y torch torchvision \
        && pip install /install/*.whl; \
    fi

FROM install_deps AS kernel-builder

# Build vllm kernels
FROM kernel-builder AS vllm-builder
WORKDIR /usr/src

COPY server/Makefile-vllm Makefile

# Build specific version of vllm
RUN make build-vllm-rocm

# Build Flash Attention v2 kernels
FROM kernel-builder AS flash-att-v2-builder
WORKDIR /usr/src

COPY server/Makefile-flash-att-v2 Makefile

# Build specific version of flash attention v2
RUN make build-flash-attention-v2-rocm

# Build Transformers CUDA kernels (gpt-neox and bloom)
FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
RUN python setup.py build

# Build exllama kernels
FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .
RUN python setup.py build

# Build exllama v2 kernels
FROM kernel-builder AS exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/exllamav2_kernels/ .
# Last step of the exllamav2-kernels-builder stage: build the exllama v2 kernels.
RUN python setup.py build

FROM install_deps AS base-copy

# Text Generation Inference base env
ENV HF_HOME=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80

# NOTE: the lib.linux-x86_64-cpython-311 / python3.11 paths below assume the
# Python 3.11 interpreter installed in the base stage (PYTHON_VERSION).

# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

# Install server (`COPY server server` already includes server/Makefile)
COPY proto proto
COPY server server
RUN cd server && \
    make gen-server && \
    pip install -r requirements_rocm.txt --no-cache-dir && \
    pip install ".[accelerate, compressed-tensors, peft, outlines]" --no-cache-dir

# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"

# ROCm runtime settings. They are defined in this shared stage so that both the
# default image and the sagemaker image get them.
# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
ENV HIP_FORCE_DEV_KERNARG=1

# On MI250 and MI300, performance for flash with Triton FA is slightly better than CK.
# However, Triton requires tuning for each prompt length, which is prohibitive.
ENV ROCM_USE_FLASH_ATTN_V2_TRITON=0
ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
ENV VLLM_MOE_PADDING=0
ENV ATTENTION=paged
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV ROCM_USE_SKINNY_GEMM=1

# AWS Sagemaker compatible image
# Built on base-copy (not the bare ROCm toolchain `base`) so the server, kernels,
# router, launcher and ROCm settings above are present in this image.
FROM base-copy AS sagemaker

COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]

# Final image
FROM base-copy

COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]