From c2681b2beac6615ec72857a9ed95f51ff781ab49 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 29 Jan 2024 20:09:38 +0000 Subject: [PATCH] feat: build custom selective-scan kernels --- server/.gitignore | 1 + server/Makefile | 1 + server/Makefile-selective-scan | 26 +++++++++++++++++++ .../models/custom_modeling/mamba_modeling.py | 5 ++++ 4 files changed, 33 insertions(+) create mode 100644 server/Makefile-selective-scan diff --git a/server/.gitignore b/server/.gitignore index dcb8fe67..576746ee 100644 --- a/server/.gitignore +++ b/server/.gitignore @@ -161,3 +161,4 @@ flash-attention-v2/ vllm/ llm-awq/ eetq/ +mamba/ diff --git a/server/Makefile b/server/Makefile index b1926828..31d55c41 100644 --- a/server/Makefile +++ b/server/Makefile @@ -3,6 +3,7 @@ include Makefile-flash-att-v2 include Makefile-vllm include Makefile-awq include Makefile-eetq +include Makefile-selective-scan unit-tests: pytest -s -vv -m "not private" tests diff --git a/server/Makefile-selective-scan b/server/Makefile-selective-scan new file mode 100644 index 00000000..cf724d55 --- /dev/null +++ b/server/Makefile-selective-scan @@ -0,0 +1,26 @@ +selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137 + +causal-conv1d: + rm -rf causal-conv1d + git clone https://github.com/Dao-AILab/causal-conv1d.git + +build-causal-conv1d: causal-conv1d + cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag + cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build + +install-causal-conv1d: build-causal-conv1d + pip uninstall causal-conv1d -y || true + cd causal-conv1d/ && pip install . + +# selective-scan dependends on causal-conv1d +selective-scan: + rm -rf mamba + git clone https://github.com/state-spaces/mamba.git mamba + +build-selective-scan: selective-scan + cd mamba/ && git fetch && git checkout $(selective_scan_commit) + cd mamba && python setup.py build + +install-selective-scan: install-causal-conv1d build-selective-scan + pip uninstall selective-scan-cuda -y || true + cd mamba && pip install . diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py index 076d8dfd..7fb20344 100644 --- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -13,6 +13,11 @@ from text_generation_server.utils.layers import ( ) +# note torch must be imported before the custom cuda modules +# since they rely on torch's libc10.so +import causal_conv1d_cuda +import selective_scan_cuda + class MambaConfig(PretrainedConfig): def __init__( self,