feat(docker): improve flash_attention caching (#160)

OlivierDehaene committed on 2023-04-09 19:59:16 +02:00 (via GitHub)
parent 3f2542bb6a · commit 1883d8ecde
4 changed files with 26 additions and 21 deletions

Dockerfile

@@ -56,14 +56,16 @@ WORKDIR /usr/src
 # Install torch
 RUN pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
 
-COPY server/Makefile server/Makefile
-
 # Install specific version of flash attention
+COPY server/Makefile-flash-att server/Makefile
 RUN cd server && make install-flash-attention
 
 # Install specific version of transformers
+COPY server/Makefile-transformers server/Makefile
 RUN cd server && BUILD_EXTENSIONS="True" make install-transformers
 
+COPY server/Makefile server/Makefile
+
 # Install server
 COPY proto proto
 COPY server server
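
Why the reordering matters: Docker invalidates a RUN layer's cache whenever any file COPY'd before it changes. Previously the full server/Makefile, which pinned both the transformers and the flash-attention commits, was copied ahead of both build steps, so bumping either pin forced a rebuild of the (slow) flash-attention layer. Copying only the fragment a step actually reads keeps that layer cached when the transformers pin changes. A minimal sketch of the pattern, with hypothetical dep-a/dep-b names that are not from this repo:

    # Each RUN layer is cached against only the files copied before it.
    COPY deps/dep-a.mk Makefile
    RUN make install-dep-a          # rebuilt only when deps/dep-a.mk changes

    COPY deps/dep-b.mk Makefile
    RUN make install-dep-b          # rebuilt only when deps/dep-b.mk changes

    # The full Makefile is copied last, once the expensive builds are done.
    COPY Makefile Makefile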

server/Makefile

@@ -1,5 +1,5 @@
-transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef
-flash_att_commit := 4d87e4d875077ad9efd25030efa4ab0ba92c19e1
+include Makefile-transformers
+include Makefile-flash-att
 
 gen-server:
 	# Compile protos
@@ -10,23 +10,6 @@ gen-server:
 	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
 	touch text_generation_server/pb/__init__.py
 
-install-transformers:
-	# Install specific version of transformers with custom cuda kernels
-	pip uninstall transformers -y || true
-	rm -rf transformers || true
-	git clone https://github.com/OlivierDehaene/transformers.git
-	cd transformers && git checkout $(transformers_commit)
-	cd transformers && python setup.py install
-
-install-flash-attention:
-	# Install specific version of flash attention
-	pip install packaging
-	pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true
-	rm -rf flash-attention || true
-	git clone https://github.com/HazyResearch/flash-attention.git
-	cd flash-attention && git checkout $(flash_att_commit)
-	cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
-
 install-torch:
 	# Install specific version of torch
 	pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
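
The fragments slot back into this Makefile via include, which make evaluates as if the fragment's text were written inline, so every target remains runnable from server/Makefile during local development while the Docker build copies in only the fragment a given layer needs. A tiny illustration of the include mechanism, using made-up names (pin.mk, dep_commit, fetch-dep are not from this repo):

    # Makefile
    include pin.mk        # brings in dep_commit and the fetch-dep target

    # pin.mk
    dep_commit := 0123abcd
    fetch-dep:
    	git clone https://example.com/dep.git && cd dep && git checkout $(dep_commit)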

server/Makefile-flash-att (new file, 10 lines)

@@ -0,0 +1,10 @@
+flash_att_commit := d478eeec8f16c7939c54e4617dbd36f59b8eeed7
+
+install-flash-attention:
+	# Install specific version of flash attention
+	pip install packaging
+	pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true
+	rm -rf flash-attention || true
+	git clone https://github.com/HazyResearch/flash-attention.git
+	cd flash-attention && git checkout $(flash_att_commit)
+	cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
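
The target builds flash-attention at the pinned commit, then the fused layer-norm and rotary-embedding extensions from its csrc/ tree. To reproduce the Docker step locally, the same invocation the Dockerfile uses should work, assuming a CUDA toolchain and the server's Python environment are already in place:

    cd server && make install-flash-attention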

server/Makefile-transformers (new file, 10 lines)

@@ -0,0 +1,10 @@
+transformers_commit := b8d969ff47c6a9d40538a6ea33df021953363afc
+
+install-transformers:
+	# Install specific version of transformers with custom cuda kernels
+	pip install --upgrade setuptools
+	pip uninstall transformers -y || true
+	rm -rf transformers || true
+	git clone https://github.com/OlivierDehaene/transformers.git
+	cd transformers && git checkout $(transformers_commit)
+	cd transformers && python setup.py install
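
This fragment pins OlivierDehaene's transformers fork with custom CUDA kernels. The Dockerfile runs it with extensions enabled; the equivalent local invocation mirrors that RUN line:

    cd server && BUILD_EXTENSIONS="True" make install-transformers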