From 8984ce6c69fe42eed2d265d9656abd6a433a19c3 Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 6 Jun 2024 15:57:00 +0000
Subject: [PATCH] feat: prefer lorax's custom punica kernels and add mlp loras

---
 server/Makefile-lorax-punica                    |  9 +++++++++
 .../custom_modeling/flash_llama_modeling.py     | 16 ++++++++++------
 server/text_generation_server/models/model.py   | 14 ++++++++++++++
 server/text_generation_server/server.py         |  2 +-
 server/text_generation_server/utils/sgmv.py     |  4 +---
 5 files changed, 35 insertions(+), 10 deletions(-)
 create mode 100644 server/Makefile-lorax-punica

diff --git a/server/Makefile-lorax-punica b/server/Makefile-lorax-punica
new file mode 100644
index 00000000..2ba8e5a7
--- /dev/null
+++ b/server/Makefile-lorax-punica
@@ -0,0 +1,9 @@
+lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc
+
+lorax-punica:
+	git clone --no-checkout https://github.com/predibase/lorax.git
+
+install-lorax-punica: lorax-punica
+	cd lorax && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit)
+	cd lorax && git submodule update --init --recursive
+	cd lorax/server/punica_kernels && python setup.py install
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 42a28cc6..c84e2290 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -226,7 +226,9 @@ class FlashLlamaAttention(torch.nn.Module):
             max_s,
         )
 
-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+        )
 
 
 class LlamaMLP(nn.Module):
@@ -295,7 +297,7 @@ class LlamaMLP(nn.Module):
         # TODO: This is a hotfix to be removed & properly refactored.
         self.quantize = config.quantize
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
@@ -309,11 +311,13 @@ class LlamaMLP(nn.Module):
                 device="cuda",
             )
             _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
-            return self.down_proj(out)
+            return self.down_proj(out, adapter_data)
         else:
-            gate_up_states = self.gate_up_proj(hidden_states)
+            gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
             gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-            return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+            return self.down_proj(
+                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+            )
 
 
 class FlashLlamaLayer(nn.Module):
@@ -373,7 +377,7 @@ class FlashLlamaLayer(nn.Module):
                 attn_output, res
             )
 
-        mlp_output = self.mlp(normed_attn_res_output)
+        mlp_output = self.mlp(normed_attn_res_output, adapter_data)
 
         return mlp_output, attn_res
 
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index a5ef7908..ca259737 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -63,6 +63,7 @@ class Model(ABC):
         world_size: int = 1,
         sliding_window: Optional[int] = None,
         speculate: Optional[int] = None,
+        adapter_id: str = BASE_MODEL_ADAPTER_ID,
     ):
         self.model_id = model_id
         self.model = model.eval()
@@ -87,6 +88,19 @@ class Model(ABC):
         )
         self.target_to_layer = self.adapter_target_to_layer()
         self.loaded_adapters = set()
+        self.static_adapter_id = adapter_id
+
+        # TODO: review moving adapter loading to the model
+        if adapter_id and adapter_id != BASE_MODEL_ADAPTER_ID:
+            pass
+            # download_adapter(adapter_id, adapter_source, api_token=None)
+            # self.load_adapter(
+            #     AdapterParameters(adapter_ids=[adapter_id]),
+            #     adapter_source,
+            #     adapter_index=0,
+            #     api_token=None,
+            #     dynamic=False,
+            # )
 
         if speculate is None:
             speculate = get_speculate()
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 4a059776..bc410ced 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -251,7 +251,7 @@ def serve(
                 majority_sign_method=0,
             )
             adapter_source = None
-            adapter_index = None
+            adapter_index = 0
             api_token = None
 
             model.load_adapter(
diff --git a/server/text_generation_server/utils/sgmv.py b/server/text_generation_server/utils/sgmv.py
index f6551e0f..d9e0d400 100644
--- a/server/text_generation_server/utils/sgmv.py
+++ b/server/text_generation_server/utils/sgmv.py
@@ -7,9 +7,7 @@ import torch
 import torch.nn.functional as F
 
 try:
-    # TODO: add build steps for Punica kernels
-    # import punica_kernels as _kernels
-    import punica.ops as _kernels
+    import punica_kernels as _kernels
 
     HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
 except ImportError:
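
Note (reviewer aside, not part of the patch): after the sgmv.py change, the SGMV
fast path depends on the `punica_kernels` module built by the new
server/Makefile-lorax-punica (e.g. `make -f Makefile-lorax-punica
install-lorax-punica` from the server/ directory). Below is a minimal standalone
sketch for sanity-checking the build, mirroring the try/except guard in
server/text_generation_server/utils/sgmv.py; the values in the except branch are
assumptions, since that branch is not shown in this hunk.

    import os

    try:
        # Kernels installed by `python setup.py install` in lorax/server/punica_kernels
        import punica_kernels as _kernels

        # SGMV stays enabled unless explicitly disabled via the environment
        HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
    except ImportError:
        # Assumed fallback: kernels unavailable, so the SGMV path is disabled
        _kernels = None
        HAS_SGMV = False

    print(f"punica_kernels importable: {_kernels is not None}; SGMV enabled: {HAS_SGMV}")

If the import fails, re-run the Makefile targets above and confirm the build ran
in the same Python environment that serves the model.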