feat: prefer lorax's custom punica kernels and add mlp loras

drbh 2024-06-06 15:57:00 +00:00
parent d5f21d57d1
commit 8984ce6c69
5 changed files with 35 additions and 10 deletions

View File

@@ -0,0 +1,9 @@
+lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc
+
+lorax-punica:
+	git clone --no-checkout https://github.com/predibase/lorax.git
+
+install-lorax-punica: lorax-punica
+	cd lorax && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit)
+	cd lorax && git submodule update --init --recursive
+	cd lorax/server/punica_kernels && python setup.py install
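After running `make install-lorax-punica` from the Makefile above, a quick import check confirms the extension built and whether the SGMV path would be enabled. This is an illustrative sketch: the `check_punica` helper is not part of the repository, but the module name `punica_kernels` and the `DISABLE_SGMV` variable mirror the import hunk at the end of this commit.

# Sketch: confirm the lorax punica kernels built and whether SGMV would be used.
# check_punica is an illustrative helper, not part of the repository.
import importlib
import os


def check_punica() -> bool:
    try:
        importlib.import_module("punica_kernels")
    except ImportError:
        print("punica_kernels not importable; SGMV path unavailable")
        return False
    enabled = not bool(os.environ.get("DISABLE_SGMV", ""))
    print(f"punica_kernels importable; SGMV enabled: {enabled}")
    return enabled


if __name__ == "__main__":
    check_punica()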

View File

@@ -226,7 +226,9 @@ class FlashLlamaAttention(torch.nn.Module):
            max_s,
        )

-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+        )


class LlamaMLP(nn.Module):
@@ -295,7 +297,7 @@ class LlamaMLP(nn.Module):
        # TODO: This is a hotfix to be removed & properly refactored.
        self.quantize = config.quantize

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, adapter_data):
        if (
            SYSTEM == "rocm"
            and self.hidden_act == "silu"
@@ -309,11 +311,13 @@
                device="cuda",
            )
            _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
-            return self.down_proj(out)
+            return self.down_proj(out, adapter_data)
        else:
-            gate_up_states = self.gate_up_proj(hidden_states)
+            gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-            return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+            return self.down_proj(
+                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+            )


class FlashLlamaLayer(nn.Module):
@@ -373,7 +377,7 @@ class FlashLlamaLayer(nn.Module):
            attn_output, res
        )

-        mlp_output = self.mlp(normed_attn_res_output)
+        mlp_output = self.mlp(normed_attn_res_output, adapter_data)

        return mlp_output, attn_res
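The MLP changes above thread `adapter_data` into `gate_up_proj` and `down_proj` so that per-request LoRA deltas can be applied inside the projections, not only in attention. As a rough illustration of why the extra argument is needed, here is a simplified dense LoRA-aware linear layer; it is a sketch only, with `adapter_data` reduced to a single adapter id, whereas the repository's layers wrap tensor-parallel linears and dispatch to the punica SGMV kernels.

# Sketch only: a dense LoRA-aware linear whose forward accepts adapter_data,
# mirroring the call shape used in the hunks above. Hypothetical layout, not
# the repository's implementation.
from typing import Dict, Optional

import torch
import torch.nn as nn


class LoraAwareLinear(nn.Module):
    def __init__(self, base: nn.Linear, scale: float = 1.0):
        super().__init__()
        self.base = base
        self.scale = scale
        # One (A, B) pair per adapter id; hypothetical storage layout.
        self.lora_a: Dict[str, torch.Tensor] = {}
        self.lora_b: Dict[str, torch.Tensor] = {}

    def forward(
        self, hidden_states: torch.Tensor, adapter_data: Optional[str] = None
    ) -> torch.Tensor:
        out = self.base(hidden_states)
        # Apply the low-rank update only when the request selected an adapter.
        if adapter_data is not None and adapter_data in self.lora_a:
            a = self.lora_a[adapter_data]  # (in_features, rank)
            b = self.lora_b[adapter_data]  # (rank, out_features)
            out = out + self.scale * (hidden_states @ a) @ b
        return out


if __name__ == "__main__":
    layer = LoraAwareLinear(nn.Linear(8, 8))
    layer.lora_a["demo"] = torch.randn(8, 4)
    layer.lora_b["demo"] = torch.randn(4, 8)
    x = torch.randn(2, 8)
    print(layer(x).shape, layer(x, adapter_data="demo").shape)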

View File

@ -63,6 +63,7 @@ class Model(ABC):
world_size: int = 1,
sliding_window: Optional[int] = None,
speculate: Optional[int] = None,
adapter_id: str = BASE_MODEL_ADAPTER_ID,
):
self.model_id = model_id
self.model = model.eval()
@@ -87,6 +88,19 @@ class Model(ABC):
        )
        self.target_to_layer = self.adapter_target_to_layer()
        self.loaded_adapters = set()
+        self.static_adapter_id = adapter_id
+
+        # TODO: review moving adapter loading to the model
+        if adapter_id and adapter_id != BASE_MODEL_ADAPTER_ID:
+            pass
+            # download_adapter(adapter_id, adapter_source, api_token=None)
+            # self.load_adapter(
+            #     AdapterParameters(adapter_ids=[adapter_id]),
+            #     adapter_source,
+            #     adapter_index=0,
+            #     api_token=None,
+            #     dynamic=False,
+            # )

        if speculate is None:
            speculate = get_speculate()
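The constructor now records the statically configured adapter id but leaves downloading and loading to the caller, as the TODO notes. A minimal, self-contained sketch of that pattern follows; the names are stand-ins, not the repository's API.

# Minimal sketch: remember the statically configured adapter id at construction
# time and defer the actual download/load. BASE_MODEL_ADAPTER_ID's value and
# the loader hook are stand-ins.
BASE_MODEL_ADAPTER_ID = ""


class ModelStub:
    def __init__(self, adapter_id: str = BASE_MODEL_ADAPTER_ID):
        self.static_adapter_id = adapter_id
        self.loaded_adapters = set()

    def maybe_load_static_adapter(self, loader) -> None:
        # Mirrors the commented-out TODO: only act for a non-base adapter.
        if self.static_adapter_id and self.static_adapter_id != BASE_MODEL_ADAPTER_ID:
            loader(self.static_adapter_id)
            self.loaded_adapters.add(self.static_adapter_id)


if __name__ == "__main__":
    model = ModelStub(adapter_id="some-org/some-lora")
    model.maybe_load_static_adapter(lambda adapter_id: print("loading", adapter_id))
    print(model.loaded_adapters)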

View File

@@ -251,7 +251,7 @@ def serve(
            majority_sign_method=0,
        )
        adapter_source = None
-        adapter_index = None
+        adapter_index = 0
        api_token = None
        model.load_adapter(

View File

@@ -7,9 +7,7 @@
import torch
import torch.nn.functional as F

try:
-    # TODO: add build steps for Punica kernels
-    # import punica_kernels as _kernels
-    import punica.ops as _kernels
+    import punica_kernels as _kernels
    HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
except ImportError:
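For context on what the preferred kernels accelerate: SGMV (segmented gather matrix-vector multiplication, from the punica project) applies a different low-rank adapter to each contiguous segment of a packed batch. The dense reference below is written only to illustrate the semantics; the names and layout are assumptions, not the kernel's actual signature.

# Dense reference for SGMV-style batching: each segment of a packed batch gets
# its own LoRA (A, B) pair. Illustrative only; the compiled kernels fuse this.
from typing import List, Tuple

import torch


def sgmv_dense(
    x: torch.Tensor,                  # (total_tokens, hidden)
    segments: List[Tuple[int, int]],  # one (start, end) per request
    lora_a: List[torch.Tensor],       # per-request (hidden, rank)
    lora_b: List[torch.Tensor],       # per-request (rank, out)
) -> torch.Tensor:
    out = x.new_zeros(x.shape[0], lora_b[0].shape[1])
    for (start, end), a, b in zip(segments, lora_a, lora_b):
        out[start:end] = (x[start:end] @ a) @ b
    return out


if __name__ == "__main__":
    x = torch.randn(6, 8)
    segments = [(0, 3), (3, 6)]
    lora_a = [torch.randn(8, 4), torch.randn(8, 4)]
    lora_b = [torch.randn(4, 8), torch.randn(4, 8)]
    print(sgmv_dense(x, segments, lora_a, lora_b).shape)  # torch.Size([6, 8])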