feat: build custom selective-scan kernels
This commit is contained in:
parent
1c32d53fc3
commit
c2681b2bea
|
@ -161,3 +161,4 @@ flash-attention-v2/
|
||||||
vllm/
|
vllm/
|
||||||
llm-awq/
|
llm-awq/
|
||||||
eetq/
|
eetq/
|
||||||
|
mamba/
|
||||||
|
|
|
@ -3,6 +3,7 @@ include Makefile-flash-att-v2
|
||||||
include Makefile-vllm
|
include Makefile-vllm
|
||||||
include Makefile-awq
|
include Makefile-awq
|
||||||
include Makefile-eetq
|
include Makefile-eetq
|
||||||
|
include Makefile-selective-scan
|
||||||
|
|
||||||
unit-tests:
|
unit-tests:
|
||||||
pytest -s -vv -m "not private" tests
|
pytest -s -vv -m "not private" tests
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137
|
||||||
|
|
||||||
|
causal-conv1d:
|
||||||
|
rm -rf causal-conv1d
|
||||||
|
git clone https://github.com/Dao-AILab/causal-conv1d.git
|
||||||
|
|
||||||
|
build-causal-conv1d: causal-conv1d
|
||||||
|
cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag
|
||||||
|
cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build
|
||||||
|
|
||||||
|
install-causal-conv1d: build-causal-conv1d
|
||||||
|
pip uninstall causal-conv1d -y || true
|
||||||
|
cd causal-conv1d/ && pip install .
|
||||||
|
|
||||||
|
# selective-scan dependends on causal-conv1d
|
||||||
|
selective-scan:
|
||||||
|
rm -rf mamba
|
||||||
|
git clone https://github.com/state-spaces/mamba.git mamba
|
||||||
|
|
||||||
|
build-selective-scan: selective-scan
|
||||||
|
cd mamba/ && git fetch && git checkout $(selective_scan_commit)
|
||||||
|
cd mamba && python setup.py build
|
||||||
|
|
||||||
|
install-selective-scan: install-causal-conv1d build-selective-scan
|
||||||
|
pip uninstall selective-scan-cuda -y || true
|
||||||
|
cd mamba && pip install .
|
|
@ -13,6 +13,11 @@ from text_generation_server.utils.layers import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# note torch must be imported before the custom cuda modules
|
||||||
|
# since they rely on torch's libc10.so
|
||||||
|
import causal_conv1d_cuda
|
||||||
|
import selective_scan_cuda
|
||||||
|
|
||||||
class MambaConfig(PretrainedConfig):
|
class MambaConfig(PretrainedConfig):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in New Issue