Include flashinfer in the docker.
This commit is contained in:
parent
52c813527a
commit
5336755358
|
@ -184,6 +184,12 @@ WORKDIR /usr/src
|
|||
COPY server/Makefile-selective-scan Makefile
|
||||
RUN make build-all
|
||||
|
||||
# Build flashinfer
|
||||
FROM kernel-builder AS flashinfer-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/Makefile-flashinfer Makefile
|
||||
RUN make build-flashinfer
|
||||
|
||||
# Text Generation Inference base image
|
||||
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base
|
||||
|
||||
|
@ -236,6 +242,7 @@ COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/c
|
|||
# Copy build artifacts from mamba builder
|
||||
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=flashinfer-builder /usr/src/flashinfer/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# Install flash-attention dependencies
|
||||
RUN pip install einops --no-cache-dir
|
||||
|
|
|
@ -7,6 +7,7 @@ include Makefile-selective-scan
|
|||
include Makefile-lorax-punica
|
||||
include Makefile-fbgemm
|
||||
include Makefile-exllamav2
|
||||
include Makefile-flashinfer
|
||||
|
||||
unit-tests:
|
||||
pytest -s -vv -m "not private" tests
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
flashinfer_commit := v0.1.5
|
||||
|
||||
build-flashinfer:
|
||||
git clone https://github.com/flashinfer-ai/flashinfer.git flashinfer && \
|
||||
cd flashinfer && git fetch && git checkout $(flashinfer_commit) && \
|
||||
git submodule update --init --recursive && \
|
||||
cd python/ && \
|
||||
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py build
|
||||
|
||||
install-flashinfer: build-flashinfer
|
||||
cd flashinfer/python/ && \
|
||||
CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py install
|
Loading…
Reference in New Issue