feat: prefer lorax's custom punica kernels and add mlp loras
commit 8984ce6c69
parent d5f21d57d1
@@ -0,0 +1,9 @@
+lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc
+
+lorax-punica: install-lorax-punica
+	git clone --no-checkout https://github.com/predibase/lorax.git
+
+install-lorax-punica:
+	cd lorax && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit)
+	cd lorax && git submodule update --init --recursive
+	cd lorax/server/punica_kernels && python setup.py install
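This target installs LoRAX's custom Punica kernels as a Python extension named punica_kernels, the same module imported further down in this commit. As a hedged sanity check (not part of the commit), something like the following confirms the build is importable after running make lorax-punica:

    # Assumes `make lorax-punica` has already run inside the active Python environment.
    try:
        import punica_kernels
        print("punica_kernels loaded from", punica_kernels.__file__)
    except ImportError as err:
        print("punica_kernels missing; SGMV LoRA support will be disabled:", err)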
@@ -226,7 +226,9 @@ class FlashLlamaAttention(torch.nn.Module):
             max_s,
         )
 
-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+        )
 
 
 class LlamaMLP(nn.Module):
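o_proj now receives adapter_data alongside the attention output, which implies the projections here are adapter-aware wrappers rather than plain linear layers. A rough single-adapter illustration of that idea only (the class name LoraAwareLinear is hypothetical, not the wrapper used in this commit):

    import torch
    import torch.nn as nn

    class LoraAwareLinear(nn.Module):
        """Sketch: frozen base projection plus an optional low-rank LoRA delta."""

        def __init__(self, base: nn.Linear, lora_a: torch.Tensor, lora_b: torch.Tensor, scale: float):
            super().__init__()
            self.base = base        # frozen weight, shape (out_features, in_features)
            self.lora_a = lora_a    # shape (in_features, r)
            self.lora_b = lora_b    # shape (r, out_features)
            self.scale = scale

        def forward(self, x: torch.Tensor, adapter_data=None) -> torch.Tensor:
            y = self.base(x)
            if adapter_data is not None:
                # LoRA update: project down to rank r, back up, scale, and add.
                y = y + self.scale * (x @ self.lora_a) @ self.lora_b
            return y

The real layers go further: with the Punica SGMV kernels they can apply different adapters to different rows of the same batch in one call.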
@@ -295,7 +297,7 @@ class LlamaMLP(nn.Module):
         # TODO: This is a hotfix to be removed & properly refactored.
         self.quantize = config.quantize
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
@@ -309,11 +311,13 @@ class LlamaMLP(nn.Module):
                 device="cuda",
             )
             _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
-            return self.down_proj(out)
+            return self.down_proj(out, adapter_data)
         else:
-            gate_up_states = self.gate_up_proj(hidden_states)
+            gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
             gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-            return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+            return self.down_proj(
+                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+            )
 
 
 class FlashLlamaLayer(nn.Module):
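With this hunk, gate_up_proj and down_proj also receive adapter_data, extending LoRA from the attention projections to the MLP. One detail visible above: on the ROCm fused-SiLU branch the gate/up product still comes from the base weight via _custom_C.LLMM_Silu, so only down_proj sees the adapter on that path. The reason batched kernels are needed at all is that a single batch can mix requests using different adapters; a naive pure-PyTorch stand-in for that grouped computation (illustrative names only, not the commit's API) looks like:

    import torch

    def naive_segmented_lora(x, segments, lora_a, lora_b, scales):
        # Each contiguous row segment of x belongs to one adapter and is multiplied
        # by that adapter's low-rank A/B pair; SGMV kernels fuse this loop on GPU.
        out_features = lora_b[0].shape[1]
        y = torch.zeros(x.shape[0], out_features, dtype=x.dtype, device=x.device)
        for (start, end), a, b, s in zip(segments, lora_a, lora_b, scales):
            y[start:end] = s * (x[start:end] @ a) @ b
        return y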
@@ -373,7 +377,7 @@ class FlashLlamaLayer(nn.Module):
             attn_output, res
         )
 
-        mlp_output = self.mlp(normed_attn_res_output)
+        mlp_output = self.mlp(normed_attn_res_output, adapter_data)
 
         return mlp_output, attn_res
 
@@ -63,6 +63,7 @@ class Model(ABC):
         world_size: int = 1,
         sliding_window: Optional[int] = None,
         speculate: Optional[int] = None,
+        adapter_id: str = BASE_MODEL_ADAPTER_ID,
     ):
         self.model_id = model_id
         self.model = model.eval()
@@ -87,6 +88,19 @@ class Model(ABC):
         )
         self.target_to_layer = self.adapter_target_to_layer()
         self.loaded_adapters = set()
+        self.static_adapter_id = adapter_id
+
+        # TODO: review moving adapter loading to the model
+        if adapter_id and adapter_id != BASE_MODEL_ADAPTER_ID:
+            pass
+            # download_adapter(adapter_id, adapter_source, api_token=None)
+            # self.load_adapter(
+            #     AdapterParameters(adapter_ids=[adapter_id]),
+            #     adapter_source,
+            #     adapter_index=0,
+            #     api_token=None,
+            #     dynamic=False,
+            # )
 
         if speculate is None:
             speculate = get_speculate()
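The base Model constructor now records a static_adapter_id, while actual loading stays commented out here and is wired up through serve instead (see the next hunk). The guard follows a common sentinel-default pattern; a small illustration (the value of BASE_MODEL_ADAPTER_ID below is an assumption, the real constant is defined elsewhere in the server):

    # Hypothetical stand-in; in the server BASE_MODEL_ADAPTER_ID is imported, not redefined.
    BASE_MODEL_ADAPTER_ID = "__base_model__"

    def wants_static_adapter(adapter_id: str = BASE_MODEL_ADAPTER_ID) -> bool:
        # Only a real adapter id (non-empty and not the base-model sentinel)
        # should trigger adapter download and loading.
        return bool(adapter_id) and adapter_id != BASE_MODEL_ADAPTER_ID

    assert not wants_static_adapter()              # default: serve the base model only
    assert wants_static_adapter("org/my-lora")     # an explicit adapter id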
@@ -251,7 +251,7 @@ def serve(
             majority_sign_method=0,
         )
         adapter_source = None
-        adapter_index = None
+        adapter_index = 0
         api_token = None
 
         model.load_adapter(
@@ -7,9 +7,7 @@ import torch
 import torch.nn.functional as F
 
 try:
-    # TODO: add build steps for Punica kernels
-    # import punica_kernels as _kernels
-    import punica.ops as _kernels
+    import punica_kernels as _kernels
 
     HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
 except ImportError:
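The HAS_SGMV flag lets call sites skip the fused path when the kernels are absent or disabled via DISABLE_SGMV, so a missing build degrades instead of failing at import time. A hedged sketch of that dispatch pattern (the helper name and fallback are illustrative, not this module's real API):

    import os

    try:
        import punica_kernels as _kernels  # LoRAX's custom Punica build
        HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
    except ImportError:
        _kernels = None
        HAS_SGMV = False

    def add_lora(y, x, lora_a, lora_b):
        if HAS_SGMV:
            ...  # fused path: hand the whole batch to the custom kernels
        else:
            ...  # fallback: plain PyTorch matmuls per adapter segment
        return y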