feat: build custom selective-scan kernels

This commit is contained in:
drbh 2024-01-29 20:09:38 +00:00
parent 1c32d53fc3
commit c2681b2bea
4 changed files with 33 additions and 0 deletions

1
server/.gitignore vendored
View File

@ -161,3 +161,4 @@ flash-attention-v2/
vllm/ vllm/
llm-awq/ llm-awq/
eetq/ eetq/
mamba/

View File

@ -3,6 +3,7 @@ include Makefile-flash-att-v2
include Makefile-vllm include Makefile-vllm
include Makefile-awq include Makefile-awq
include Makefile-eetq include Makefile-eetq
include Makefile-selective-scan
unit-tests: unit-tests:
pytest -s -vv -m "not private" tests pytest -s -vv -m "not private" tests

View File

@ -0,0 +1,26 @@
selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137
causal-conv1d:
rm -rf causal-conv1d
git clone https://github.com/Dao-AILab/causal-conv1d.git
build-causal-conv1d: causal-conv1d
cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag
cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build
install-causal-conv1d: build-causal-conv1d
pip uninstall causal-conv1d -y || true
cd causal-conv1d/ && pip install .
# selective-scan dependends on causal-conv1d
selective-scan:
rm -rf mamba
git clone https://github.com/state-spaces/mamba.git mamba
build-selective-scan: selective-scan
cd mamba/ && git fetch && git checkout $(selective_scan_commit)
cd mamba && python setup.py build
install-selective-scan: install-causal-conv1d build-selective-scan
pip uninstall selective-scan-cuda -y || true
cd mamba && pip install .

View File

@ -13,6 +13,11 @@ from text_generation_server.utils.layers import (
) )
# note torch must be imported before the custom cuda modules
# since they rely on torch's libc10.so
import causal_conv1d_cuda
import selective_scan_cuda
class MambaConfig(PretrainedConfig): class MambaConfig(PretrainedConfig):
def __init__( def __init__(
self, self,