Include flashinfer in the docker.

This commit is contained in:
Nicolas Patry 2024-08-16 23:50:37 +02:00
parent 52c813527a
commit 5336755358
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
3 changed files with 20 additions and 0 deletions

View File

@ -184,6 +184,12 @@ WORKDIR /usr/src
COPY server/Makefile-selective-scan Makefile
RUN make build-all
# Build flashinfer
FROM kernel-builder AS flashinfer-builder
WORKDIR /usr/src
COPY server/Makefile-flashinfer Makefile
RUN make build-flashinfer
# Text Generation Inference base image
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base
@ -236,6 +242,7 @@ COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/c
# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
COPY --from=flashinfer-builder /usr/src/flashinfer/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
# Install flash-attention dependencies
RUN pip install einops --no-cache-dir

View File

@ -7,6 +7,7 @@ include Makefile-selective-scan
include Makefile-lorax-punica
include Makefile-fbgemm
include Makefile-exllamav2
include Makefile-flashinfer
unit-tests:
pytest -s -vv -m "not private" tests

View File

@ -0,0 +1,12 @@
flashinfer_commit := v0.1.5
build-flashinfer:
git clone https://github.com/flashinfer-ai/flashinfer.git flashinfer && \
cd flashinfer && git fetch && git checkout $(flashinfer_commit) && \
git submodule update --init --recursive && \
cd python/ && \
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py build
install-flashinfer: build-flashinfer
cd flashinfer/python/ && \
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py install