feat: prefer lorax's custom punica kernels and add mlp loras
This commit is contained in:
parent
d5f21d57d1
commit
8984ce6c69
@@ -0,0 +1,9 @@
+lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc
+
+lorax-punica: install-lorax-punica
+	git clone --no-checkout https://github.com/predibase/lorax.git
+
+install-lorax-punica:
+	cd lorax && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit)
+	cd lorax && git submodule update --init --recursive
+	cd lorax/server/punica_kernels && python setup.py install
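As a quick sanity check after running the install target above, the built extension can be probed from Python. This is an illustrative sketch only: the module name `punica_kernels` and the `DISABLE_SGMV` environment variable are taken from the lora-layer hunk further down this diff, everything else is an assumption.

# Illustrative post-install check; not part of this commit.
import importlib.util
import os

has_kernels = importlib.util.find_spec("punica_kernels") is not None
has_sgmv = has_kernels and not bool(os.environ.get("DISABLE_SGMV", ""))
print(f"punica_kernels available: {has_kernels}, SGMV enabled: {has_sgmv}")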
@@ -226,7 +226,9 @@ class FlashLlamaAttention(torch.nn.Module):
             max_s,
         )

-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+        )


 class LlamaMLP(nn.Module):
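The `adapter_data` argument threaded into `o_proj` here implies the projection layers are adapter-aware. The repository's actual layer type is not shown in this diff; the following is a minimal sketch of what such a wrapper could look like, where `AdapterData`, the single-adapter layout, and the scaling field are illustrative assumptions.

# Minimal sketch, not the repository's implementation: an adapter-aware
# projection that adds a low-rank LoRA delta on top of the frozen base linear.
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class AdapterData:
    lora_a: torch.Tensor  # (in_features, r)
    lora_b: torch.Tensor  # (r, out_features)
    scaling: float


class AdapterAwareLinear(nn.Module):
    def __init__(self, base: nn.Linear):
        super().__init__()
        self.base = base

    def forward(self, x: torch.Tensor, adapter_data=None) -> torch.Tensor:
        out = self.base(x)
        if adapter_data is not None:
            # y = W x + scaling * (x A) B
            out = out + adapter_data.scaling * (x @ adapter_data.lora_a) @ adapter_data.lora_b
        return out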
@@ -295,7 +297,7 @@ class LlamaMLP(nn.Module):
         # TODO: This is a hotfix to be removed & properly refactored.
         self.quantize = config.quantize

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
@@ -309,11 +311,13 @@
                 device="cuda",
             )
             _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
-            return self.down_proj(out)
+            return self.down_proj(out, adapter_data)
         else:
-            gate_up_states = self.gate_up_proj(hidden_states)
+            gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
             gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-            return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+            return self.down_proj(
+                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+            )


 class FlashLlamaLayer(nn.Module):
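A compact sketch of the adapter-aware SwiGLU path this hunk converges on, assuming `gate_up_proj` and `down_proj` accept `adapter_data` like the wrapper sketched earlier; the class below is a stand-in, not the file's `LlamaMLP`.

# Illustrative only: threading adapter_data through a fused gate/up MLP.
import torch
import torch.nn as nn


class AdapterAwareMLP(nn.Module):
    def __init__(self, gate_up_proj, down_proj, intermediate_size: int):
        super().__init__()
        self.gate_up_proj = gate_up_proj
        self.down_proj = down_proj
        self.intermediate_size = intermediate_size
        self.act = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor, adapter_data=None) -> torch.Tensor:
        # Fused projection produces [gate, up] along the last dimension.
        gate_up = self.gate_up_proj(hidden_states, adapter_data)
        gate_up = gate_up.view(-1, 2, self.intermediate_size)
        # SwiGLU: silu(gate) * up, then the adapter-aware down projection.
        return self.down_proj(self.act(gate_up[:, 0]) * gate_up[:, 1], adapter_data)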
@@ -373,7 +377,7 @@ class FlashLlamaLayer(nn.Module):
             attn_output, res
         )

-        mlp_output = self.mlp(normed_attn_res_output)
+        mlp_output = self.mlp(normed_attn_res_output, adapter_data)

         return mlp_output, attn_res

@@ -63,6 +63,7 @@ class Model(ABC):
         world_size: int = 1,
         sliding_window: Optional[int] = None,
         speculate: Optional[int] = None,
+        adapter_id: str = BASE_MODEL_ADAPTER_ID,
     ):
         self.model_id = model_id
         self.model = model.eval()
@@ -87,6 +88,19 @@
         )
         self.target_to_layer = self.adapter_target_to_layer()
         self.loaded_adapters = set()
+        self.static_adapter_id = adapter_id
+
+        # TODO: review moving adapter loading to the model
+        if adapter_id and adapter_id != BASE_MODEL_ADAPTER_ID:
+            pass
+            # download_adapter(adapter_id, adapter_source, api_token=None)
+            # self.load_adapter(
+            #     AdapterParameters(adapter_ids=[adapter_id]),
+            #     adapter_source,
+            #     adapter_index=0,
+            #     api_token=None,
+            #     dynamic=False,
+            # )

         if speculate is None:
             speculate = get_speculate()
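The `Model` changes above are mostly bookkeeping: a set of loaded adapters and the static adapter id passed at construction, with eager loading still commented out. A self-contained toy illustration of that state follows; the names mirror the diff, while the empty-string sentinel for `BASE_MODEL_ADAPTER_ID` is an assumption.

# Toy stand-in for the bookkeeping added to Model.__init__; not the real class.
BASE_MODEL_ADAPTER_ID = ""  # assumption: sentinel meaning "no adapter"


class AdapterBookkeeping:
    def __init__(self, adapter_id: str = BASE_MODEL_ADAPTER_ID):
        self.loaded_adapters = set()          # adapter indices loaded so far
        self.static_adapter_id = adapter_id   # adapter requested at startup

    def wants_static_adapter(self) -> bool:
        return bool(self.static_adapter_id) and self.static_adapter_id != BASE_MODEL_ADAPTER_ID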
|
@ -251,7 +251,7 @@ def serve(
|
|||
majority_sign_method=0,
|
||||
)
|
||||
adapter_source = None
|
||||
adapter_index = None
|
||||
adapter_index = 0
|
||||
api_token = None
|
||||
|
||||
model.load_adapter(
|
||||
|
|
|
@@ -7,9 +7,7 @@ import torch
 import torch.nn.functional as F

 try:
-    # TODO: add build steps for Punica kernels
-    # import punica_kernels as _kernels
-    import punica.ops as _kernels
+    import punica_kernels as _kernels

     HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
 except ImportError:
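The import hunk above switches from the upstream `punica.ops` module to the locally built `punica_kernels` extension installed by the Makefile target at the top of this diff. The hunk ends before the `except` body, so the fallback shown below is only an assumption about how `HAS_SGMV` degrades when the extension is missing.

# Sketch of the guard pattern; the except-branch contents are assumed.
import os

try:
    import punica_kernels as _kernels

    HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
except ImportError:
    _kernels = None
    HAS_SGMV = False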