From 0a6ea7fb57b10f5a307c575329ca0bddf5f28d03 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 4 Jun 2024 01:38:43 +0000 Subject: [PATCH] feat: load weights within layer and refactor lora pass --- .../custom_modeling/flash_llama_modeling.py | 149 ++++++++++-------- .../models/flash_causal_lm.py | 4 +- server/text_generation_server/utils/hub.py | 26 +++ 3 files changed, 109 insertions(+), 70 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 27703e14..b790896d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -126,27 +126,51 @@ class FlashLlamaAttention(torch.nn.Module): self.query_key_value = load_attention(config, prefix, weights) self.index = index self.adapter_weights = {} - for adapter_id, adapter_weights in all_adapter_weights.items(): - filtered_keys = list( - filter( - lambda x: x.startswith( - f"base_model.model.model.layers.{index}.self_attn" - ), - adapter_weights.keys(), - ) - ) - self.adapter_weights[adapter_id] = { - key: torch.tensor( - adapter_weights[key], - device=weights.device, - dtype=weights.dtype, - ).T - for key in filtered_keys - } + adapter_names = list(all_adapter_weights.keys()) - self.index_to_key = { - i: key for i, key in enumerate(self.adapter_weights.keys()) - } + self.lora_a_matrix = torch.empty( + (len(adapter_names), 2, 4096, 8), + device=weights.device, + dtype=weights.dtype, + ) + self.lora_b_matrix = torch.empty( + (len(adapter_names), 2, 8, 4096), + device=weights.device, + dtype=weights.dtype, + ) + + self.pre_multiplied_lora_matrix = torch.empty( + (len(adapter_names), 2, 4096, 4096), + device=weights.device, + dtype=weights.dtype, + ) + + self.key_to_index = {} + self.index_to_key = {} + + lora_prefix = f"base_model.model.model.layers.{index}.self_attn" + for adapter_index, adapter_name in enumerate(adapter_names): + self.lora_alpha = 16.0 + self.lora_r = 8.0 + self.lora_scale = self.lora_alpha / self.lora_r + self.key_to_index[adapter_name] = adapter_index + self.index_to_key[adapter_index] = adapter_name + adapter_weights = all_adapter_weights[adapter_name] + for target_index, target in enumerate(["q", "v"]): + adapter_weight_a = adapter_weights.get_tensor( + f"{lora_prefix}.{target}_proj.lora_A.weight" + ) + adapter_weight_b = adapter_weights.get_tensor( + f"{lora_prefix}.{target}_proj.lora_B.weight" + ) + pre_multiplied_lora_matrix = torch.matmul( + adapter_weight_a.T * self.lora_scale, + adapter_weight_b.T, + ).contiguous() + + self.pre_multiplied_lora_matrix[adapter_index, target_index, :, :] = ( + pre_multiplied_lora_matrix + ) self.o_proj = TensorParallelRowLinear.load( config, @@ -159,23 +183,6 @@ class FlashLlamaAttention(torch.nn.Module): 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device ).repeat_interleave(self.num_groups) - def get_adapter_weights(self, lora_index): - adapter_id = self.index_to_key[lora_index] - q_proj_lora_a = self.adapter_weights[adapter_id][ - f"base_model.model.model.layers.{self.index}.self_attn.q_proj.lora_A.weight" - ] - q_proj_lora_b = self.adapter_weights[adapter_id][ - f"base_model.model.model.layers.{self.index}.self_attn.q_proj.lora_B.weight" - ] - - v_proj_lora_a = self.adapter_weights[adapter_id][ - f"base_model.model.model.layers.{self.index}.self_attn.v_proj.lora_A.weight" - ] - v_proj_lora_b = self.adapter_weights[adapter_id][ 
- f"base_model.model.model.layers.{self.index}.self_attn.v_proj.lora_B.weight" - ] - return q_proj_lora_a, q_proj_lora_b, v_proj_lora_a, v_proj_lora_b - def forward( self, hidden_states, @@ -201,39 +208,42 @@ class FlashLlamaAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) - q_proj_lora_a, q_proj_lora_b, v_proj_lora_a, v_proj_lora_b = ( - self.get_adapter_weights( - # TODO: dont just assume the first adapter - lora_indices[0].item() - ) - ) - - query_adapted = torch.matmul( - hidden_states, - torch.matmul( - q_proj_lora_a, - q_proj_lora_b, - ), - ) - - value_adapted = torch.matmul( - hidden_states, - torch.matmul( - v_proj_lora_a, - v_proj_lora_b, - ), - ) - batch_size = query.size(0) + if not torch.all(lora_indices, -1): + lora_mask = lora_indices[lora_indices != -1] - # TODO: improve this to avoid unnecessary work - # mask across batch and within lora adapters - query[batch_lora_adapter_mask] += query_adapted.view( - batch_size, self.num_heads, self.head_size - )[batch_lora_adapter_mask] - kv[batch_lora_adapter_mask, 1] += value_adapted.view( - batch_size, self.num_key_value_heads, self.head_size - )[batch_lora_adapter_mask] + q_pre_multiplied_batch = torch.ones( + (batch_size, 4096, 4096), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + q_pre_multiplied_batch[lora_mask] = self.pre_multiplied_lora_matrix[ + lora_mask, 0 + ] + + v_pre_multiplied_batch = torch.ones( + (batch_size, 4096, 4096), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + v_pre_multiplied_batch[lora_mask] = self.pre_multiplied_lora_matrix[ + lora_mask, 1 + ] + + query_adapted = ( + torch.bmm(hidden_states.unsqueeze(1), q_pre_multiplied_batch) + .squeeze(1) + .view(batch_size, self.num_heads, self.head_size) + ) + value_adapted = ( + torch.bmm(hidden_states.unsqueeze(1), v_pre_multiplied_batch) + .squeeze(1) + .view(batch_size, self.num_key_value_heads, self.head_size) + ) + query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask] + kv[batch_lora_adapter_mask, 1] += value_adapted[batch_lora_adapter_mask] self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) @@ -503,6 +513,9 @@ class FlashLlamaForCausalLM(torch.nn.Module): weights=weights, ) + def get_lora_index(self, adapter_id): + return self.model.layers[0].self_attn.key_to_index[adapter_id] + def forward( self, input_ids: torch.Tensor, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 28527113..0062cd55 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1064,11 +1064,11 @@ class FlashCausalLM(Model): cuda_graph = None batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device) - lora_indices = torch.zeros(bs, dtype=torch.int32, device=self.device) + lora_indices = torch.full((bs,), -1, dtype=torch.int32, device=self.device) for i, r in enumerate(batch.requests): if r.adapter_id: - lora_index = int(r.adapter_id) + lora_index = self.model.get_lora_index(r.adapter_id) lora_indices[i] = lora_index batch_lora_adapter_mask[i] = True diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index b56484f6..d41700e8 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -18,6 +18,17 @@ WEIGHTS_CACHE_OVERRIDE = 
os.getenv("WEIGHTS_CACHE_OVERRIDE", None) HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"] +def _cached_adapter_weight_files( + adapter_id: str, revision: Optional[str], extension: str +) -> List[str]: + """Guess weight files from the cached revision snapshot directory""" + d = _get_cached_revision_directory(adapter_id, revision) + if not d: + return [] + filenames = _adapter_weight_files_from_dir(d, extension) + return filenames + + def _cached_weight_files( model_id: str, revision: Optional[str], extension: str ) -> List[str]: @@ -60,6 +71,21 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]: return filenames +def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]: + # os.walk: do not iterate, just scan for depth 1, not recursively + # see _weight_files_from_dir, that's also what is done there + root, _, files = next(os.walk(str(d))) + filenames = [ + os.path.join(root, f) + for f in files + if f.endswith(extension) + and "arguments" not in f + and "args" not in f + and "training" not in f + ] + return filenames + + def _get_cached_revision_directory( model_id: str, revision: Optional[str] ) -> Optional[Path]: