feat: prefer lorax's custom punica kernels and add mlp loras

drbh 2024-06-06 15:57:00 +00:00
parent d5f21d57d1
commit 8984ce6c69
5 changed files with 35 additions and 10 deletions


@@ -0,0 +1,9 @@
+lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc
+
+lorax-punica: install-lorax-punica
+	git clone --no-checkout https://github.com/predibase/lorax.git
+
+install-lorax-punica:
+	cd lorax && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit)
+	cd lorax && git submodule update --init --recursive
+	cd lorax/server/punica_kernels && python setup.py install
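Taken together, the new Makefile pins lorax to a specific commit, sparse-checks out only server/punica_kernels, initializes its submodules, and installs the kernels with setup.py; the punica_kernels extension imported at the bottom of this diff is the product of that build. How the file is wired into the overall server build (for example via an include or a CI step running `make install-lorax-punica`) is not part of this commit.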


@@ -226,7 +226,9 @@ class FlashLlamaAttention(torch.nn.Module):
                 max_s,
             )

-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+        )


 class LlamaMLP(nn.Module):
@@ -295,7 +297,7 @@ class LlamaMLP(nn.Module):
         # TODO: This is a hotfix to be removed & properly refactored.
         self.quantize = config.quantize

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
@@ -309,11 +311,13 @@ class LlamaMLP(nn.Module):
                 device="cuda",
             )
             _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
-            return self.down_proj(out)
+            return self.down_proj(out, adapter_data)
         else:
-            gate_up_states = self.gate_up_proj(hidden_states)
+            gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
             gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-            return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+            return self.down_proj(
+                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+            )


 class FlashLlamaLayer(nn.Module):
@@ -373,7 +377,7 @@ class FlashLlamaLayer(nn.Module):
             attn_output, res
         )

-        mlp_output = self.mlp(normed_attn_res_output)
+        mlp_output = self.mlp(normed_attn_res_output, adapter_data)

         return mlp_output, attn_res
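The pattern in this file is that every projection that can carry a LoRA delta (o_proj, gate_up_proj, down_proj) now receives adapter_data alongside the hidden states, which is what extends LoRA support from attention to the MLP. As a rough sketch of what such an adapter-aware linear can look like (the LoraLinear class and the adapter_data.layer_weights layout below are hypothetical names for illustration, not TGI's actual classes):

import torch
import torch.nn as nn


class LoraLinear(nn.Module):
    """Hedged sketch: wraps a base projection and applies any LoRA deltas
    registered for this layer. Real TGI/lorax code dispatches to batched
    punica SGMV kernels instead of this per-adapter Python loop."""

    def __init__(self, base: nn.Linear, layer_name: str):
        super().__init__()
        self.base = base
        self.layer_name = layer_name

    def forward(self, x: torch.Tensor, adapter_data) -> torch.Tensor:
        out = self.base(x)
        if adapter_data is None:
            return out
        # assumed layout: adapter_data.layer_weights = {layer_name: [(lora_a, lora_b, scale), ...]}
        # lora_a: (in_features, r), lora_b: (r, out_features)
        for lora_a, lora_b, scale in adapter_data.layer_weights.get(self.layer_name, []):
            out = out + (x @ lora_a @ lora_b) * scale
        return out

Threading adapter_data through LlamaMLP.forward and FlashLlamaLayer is what lets the same mechanism cover the gate/up and down projections, not just the attention output projection.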


@@ -63,6 +63,7 @@ class Model(ABC):
         world_size: int = 1,
         sliding_window: Optional[int] = None,
         speculate: Optional[int] = None,
+        adapter_id: str = BASE_MODEL_ADAPTER_ID,
     ):
         self.model_id = model_id
         self.model = model.eval()
@@ -87,6 +88,19 @@ class Model(ABC):
         )
         self.target_to_layer = self.adapter_target_to_layer()
         self.loaded_adapters = set()
+        self.static_adapter_id = adapter_id
+
+        # TODO: review moving adapter loading to the model
+        if adapter_id and adapter_id != BASE_MODEL_ADAPTER_ID:
+            pass
+            # download_adapter(adapter_id, adapter_source, api_token=None)
+            # self.load_adapter(
+            #     AdapterParameters(adapter_ids=[adapter_id]),
+            #     adapter_source,
+            #     adapter_index=0,
+            #     api_token=None,
+            #     dynamic=False,
+            # )

         if speculate is None:
             speculate = get_speculate()
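The base Model now records the statically configured adapter as self.static_adapter_id; the eager download/load path at construction time is left commented out behind the TODO, so for now adapters are still attached through model.load_adapter(...) from the serving code (next file in this diff).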


@@ -251,7 +251,7 @@ def serve(
             majority_sign_method=0,
         )
         adapter_source = None
-        adapter_index = None
+        adapter_index = 0
         api_token = None
         model.load_adapter(
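Changing adapter_index from None to 0 gives the statically configured adapter a concrete slot, which lines up with the adapter_index=0 used in the commented-out eager-load path in Model.__init__ above.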


@@ -7,9 +7,7 @@ import torch
 import torch.nn.functional as F

 try:
-    # TODO: add build steps for Punica kernels
-    # import punica_kernels as _kernels
-    import punica.ops as _kernels
+    import punica_kernels as _kernels

     HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
 except ImportError:
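With this change the SGMV kernels come from lorax's punica_kernels extension (built by the new Makefile target above) instead of the upstream punica.ops package. SGMV (segmented gather matrix-vector multiplication) is what lets a single batched kernel apply different LoRA adapters to different requests in the same batch; if the extension is missing or DISABLE_SGMV is set, HAS_SGMV is false and the fast path is skipped. A small, self-contained check of whether the build succeeded (an illustrative snippet, assuming it is run in the server's Python environment after make install-lorax-punica; it is not part of the module's API):

import importlib.util
import os

# True only if the compiled extension from lorax/server/punica_kernels is importable
print("punica_kernels found:", importlib.util.find_spec("punica_kernels") is not None)
# The escape hatch checked by HAS_SGMV above
print("DISABLE_SGMV set:", bool(os.environ.get("DISABLE_SGMV", "")))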