From aa88c4fd3afb34678d41fdf6e9d2cbf17846a086 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 14 Jun 2024 00:35:07 +0000 Subject: [PATCH] fix: add lora kernel to dockerfile, support running without kernels and refactors --- Dockerfile | 8 ++++++++ server/Makefile-lorax-punica | 15 +++++++++------ server/text_generation_server/models/model.py | 12 ------------ server/text_generation_server/utils/sgmv.py | 8 ++++++-- 4 files changed, 23 insertions(+), 20 deletions(-) diff --git a/Dockerfile b/Dockerfile index f2f6df5f..0e1f423c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -144,6 +144,13 @@ COPY server/Makefile-marlin Makefile # Build specific version of transformers RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-marlin +# Build Lorax Punica kernels +FROM kernel-builder as lorax-punica-builder +WORKDIR /usr/src +COPY server/Makefile-lorax-punica Makefile +# Build lorax punica kernels +RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica + # Build Transformers CUDA kernels FROM kernel-builder as custom-kernels-builder WORKDIR /usr/src @@ -214,6 +221,7 @@ COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86 COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy build artifacts from marlin kernels builder COPY --from=marlin-kernels-builder /usr/src/marlin/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy builds artifacts from vllm builder COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages diff --git a/server/Makefile-lorax-punica b/server/Makefile-lorax-punica index 2ba8e5a7..72f06f76 100644 --- a/server/Makefile-lorax-punica +++ b/server/Makefile-lorax-punica @@ -1,9 +1,12 @@ lorax_punica_commit := 
c71861a653412267dc27ec86013dd945ce3474bc -lorax-punica: install-lorax-punica - git clone --no-checkout https://github.com/predibase/lorax.git +build-lorax-punica: + if [ ! -d 'lorax-punica' ]; then \ + git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \ + fi + cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit) + cd lorax-punica && git submodule update --init --recursive + cd lorax-punica/server/punica_kernels && python setup.py build -install-lorax-punica: - cd lorax && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit) - cd lorax && git submodule update --init --recursive - cd lorax/server/punica_kernels && python setup.py install +install-lorax-punica: build-lorax-punica + cd lorax-punica/server/punica_kernels && python setup.py install diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index ca259737..8da44273 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -90,18 +90,6 @@ class Model(ABC): self.loaded_adapters = set() self.static_adapter_id = adapter_id - # TODO: review moving adapter loading to the model - if adapter_id and adapter_id != BASE_MODEL_ADAPTER_ID: - pass - # download_adapter(adapter_id, adapter_source, api_token=None) - # self.load_adapter( - # AdapterParameters(adapter_ids=[adapter_id]), - # adapter_source, - # adapter_index=0, - # api_token=None, - # dynamic=False, - # ) - if speculate is None: speculate = get_speculate() self.speculate = speculate diff --git a/server/text_generation_server/utils/sgmv.py b/server/text_generation_server/utils/sgmv.py index d9e0d400..7ad6288d 100644 --- a/server/text_generation_server/utils/sgmv.py +++ b/server/text_generation_server/utils/sgmv.py @@ -136,6 +136,10 @@ def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor: return torch.empty((tmp_size,), 
dtype=torch.uint8, device=device) +def get_tmp_tensor_for_size_no_kernels(size: int, device: torch.device) -> torch.Tensor: + return torch.empty((size,), dtype=torch.uint8, device=device) + + def get_tmp_expand_size(size: int) -> int: return _kernels.sgmv_cutlass_tmp_size(size) @@ -143,12 +147,12 @@ def get_tmp_expand_size(size: int) -> int: def get_tmp_tensors( nsegments: int, lora_rank: int, device: torch.device ) -> Tuple[torch.Tensor, torch.Tensor]: - if use_cutlass_shrink(lora_rank): + if use_cutlass_shrink(lora_rank) and has_sgmv(): tmp = get_tmp_tensor_for_size(nsegments, device) return tmp, tmp else: tmp_shrink = get_tmp_tensor(device) - tmp_expand = get_tmp_tensor_for_size(nsegments, device) + tmp_expand = get_tmp_tensor_for_size_no_kernels(nsegments, device) return tmp_shrink, tmp_expand