From d5f21d57d1222e76628891778c19ac2d1b48e68d Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 6 Jun 2024 14:35:59 +0000
Subject: [PATCH] fix: prefer adapter_data and refactors

---
 .../custom_modeling/flash_llama_modeling.py   | 64 ++-----------------
 .../models/flash_causal_lm.py                 | 17 -----
 .../models/flash_llama.py                     |  4 +-
 3 files changed, 5 insertions(+), 80 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 3e9cedcb..42a28cc6 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -117,8 +117,6 @@ class FlashLlamaAttention(torch.nn.Module):
         prefix: str,
         config,
         weights,
-        lora_weights,
-        lora_configs,
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
@@ -151,40 +149,6 @@ class FlashLlamaAttention(torch.nn.Module):
 
         self.query_key_value = load_attention(config, prefix, weights, index)
         self.index = index
-        self.adapter_weights = {}
-        adapter_names = list(lora_weights.keys())
-
-        self.n_loras = len(adapter_names)
-        self.pre_multiplied_lora_matrices = torch.empty(
-            (self.n_loras, 2, self.hidden_size, self.hidden_size),
-            device=weights.device,
-            dtype=weights.dtype,
-        )
-
-        self.key_to_index = {}
-
-        lora_prefix = f"base_model.model.model.layers.{index}.self_attn"
-        for adapter_index, adapter_name in enumerate(adapter_names):
-            self.lora_alpha = lora_configs[adapter_name].lora_alpha
-            self.lora_r = lora_configs[adapter_name].r
-            self.lora_scale = self.lora_alpha / self.lora_r
-            self.key_to_index[adapter_name] = adapter_index
-            adapter_weights = lora_weights[adapter_name]
-            for target_index, target in enumerate(["q", "v"]):
-                adapter_weight_a = adapter_weights.get_tensor(
-                    f"{lora_prefix}.{target}_proj.lora_A.weight"
-                )
-                adapter_weight_b = adapter_weights.get_tensor(
-                    f"{lora_prefix}.{target}_proj.lora_B.weight"
-                )
-                pre_multiplied_lora_matrix = torch.matmul(
-                    adapter_weight_a.T * self.lora_scale,
-                    adapter_weight_b.T,
-                ).contiguous()
-
-                self.pre_multiplied_lora_matrices[adapter_index, target_index, :, :] = (
-                    pre_multiplied_lora_matrix
-                )
 
         o_proj = TensorParallelRowLinear.load(
             config,
@@ -216,8 +180,6 @@ class FlashLlamaAttention(torch.nn.Module):
         slots,
         input_lengths,
         max_s,
-        batch_lora_adapter_mask,
-        lora_indices,
         adapter_data,
     ):
         qkv = self.query_key_value(hidden_states, adapter_data)
@@ -355,15 +317,13 @@ class LlamaMLP(nn.Module):
 
 
 class FlashLlamaLayer(nn.Module):
-    def __init__(self, index, prefix, config, weights, lora_weights, lora_configs):
+    def __init__(self, index, prefix, config, weights):
         super().__init__()
         self.self_attn = FlashLlamaAttention(
             index=index,
             prefix=f"{prefix}.self_attn",
             config=config,
             weights=weights,
-            lora_weights=lora_weights,
-            lora_configs=lora_configs,
         )
         self.mlp = LlamaMLP(
             prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
@@ -390,8 +350,6 @@ class FlashLlamaLayer(nn.Module):
         slots,
         input_lengths,
         max_s,
-        batch_lora_adapter_mask,
-        lora_indices,
         adapter_data,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -407,8 +365,6 @@ class FlashLlamaLayer(nn.Module):
             slots,
             input_lengths,
             max_s,
-            batch_lora_adapter_mask,
-            lora_indices,
             adapter_data,
         )
 
@@ -423,7 +379,7 @@ class FlashLlamaLayer(nn.Module):
 
 
 class FlashLlamaModel(torch.nn.Module):
-    def __init__(self, prefix, config, weights, lora_weights, lora_configs):
+    def __init__(self, prefix, config, weights):
         super().__init__()
 
         process_group = weights.process_group
@@ -440,8 +396,6 @@ class FlashLlamaModel(torch.nn.Module):
                 ),
                 config=config,
                 weights=weights,
-                lora_weights=lora_weights,
-                lora_configs=lora_configs,
             )
             for layer_id in range(config.num_hidden_layers)
         ]
@@ -470,8 +424,6 @@ class FlashLlamaModel(torch.nn.Module):
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
-        batch_lora_adapter_mask: Optional[List[str]],
-        lora_indices: Optional[torch.Tensor],
         adapter_data,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
@@ -495,8 +447,6 @@ class FlashLlamaModel(torch.nn.Module):
                 slots,
                 input_lengths,
                 max_s,
-                batch_lora_adapter_mask,
-                lora_indices,
                 adapter_data,
             )
 
@@ -506,7 +456,7 @@ class FlashLlamaModel(torch.nn.Module):
 
 
 class FlashLlamaForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights, lora_weights, lora_configs):
+    def __init__(self, prefix, config, weights):
         super().__init__()
 
         self.embed_tokens = TensorParallelEmbedding(
@@ -515,9 +465,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             ),
             weights=weights,
         )
-        self.model = FlashLlamaModel(
-            prefix, config, weights, lora_weights, lora_configs
-        )
+        self.model = FlashLlamaModel(prefix, config, weights)
         if config.tie_word_embeddings:
             suffix = "model.embed_tokens"
         else:
@@ -544,8 +492,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
-        batch_lora_adapter_mask: Optional[List[str]] = None,
-        lora_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.embed_tokens(input_ids)
@@ -560,8 +506,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             max_s,
             true_max_s=max_s,
             prefill_cache_indices=prefill_cache_indices,
-            batch_lora_adapter_mask=batch_lora_adapter_mask,
-            lora_indices=lora_indices,
             adapter_data=adapter_data,
         )
         if lm_head_indices is not None:
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index f4534bd0..61b31691 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -877,8 +877,6 @@ class FlashCausalLM(Model):
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph
 
-        batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
-        lora_indices = torch.zeros(bs, dtype=torch.int32, device=self.device)
         torch.cuda.synchronize()
         # Run once outside to warmup
         self.model.forward(
@@ -892,8 +890,6 @@ class FlashCausalLM(Model):
             max_s=max_s,
             prefill_cache_indices=None,
             lm_head_indices=None,
-            batch_lora_adapter_mask=batch_lora_adapter_mask,
-            lora_indices=lora_indices,
         )
         torch.cuda.synchronize()
 
@@ -909,8 +905,6 @@ class FlashCausalLM(Model):
                 max_s=max_s,
                 prefill_cache_indices=None,
                 lm_head_indices=None,
-                batch_lora_adapter_mask=batch_lora_adapter_mask,
-                lora_indices=lora_indices,
             )
         self.cuda_graphs[bs]["logits"] = logits
         self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
@@ -1038,10 +1032,6 @@ class FlashCausalLM(Model):
 
         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-        batch_lora_adapter_mask = torch.zeros(
-            seqlen, dtype=torch.bool, device=self.device
-        )
-        lora_indices = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
 
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
@@ -1057,8 +1047,6 @@ class FlashCausalLM(Model):
             max_s=seqlen,
             lm_head_indices=None,
             prefill_cache_indices=None,
-            batch_lora_adapter_mask=batch_lora_adapter_mask,
-            lora_indices=lora_indices,
         )
 
     def forward(
@@ -1129,9 +1117,6 @@ class FlashCausalLM(Model):
         else:
             cuda_graph = None
 
-        batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
-        lora_indices = torch.full((bs,), -1, dtype=torch.int32, device=self.device)
-
         if cu_seqlen_prefill is not None or cuda_graph is None:
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
@@ -1144,8 +1129,6 @@ class FlashCausalLM(Model):
                 max_s=max_s,
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
-                batch_lora_adapter_mask=batch_lora_adapter_mask,
-                lora_indices=lora_indices,
                 adapter_data=adapter_data,
             )
             if batch.prefill_cache_indices is not None:
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index fbe953e9..1266f6de 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -99,9 +99,7 @@ class FlashLlama(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)
 
         prefix = ""
-        model = FlashLlamaForCausalLM(
-            prefix, config, weights, lora_weights, lora_configs
-        )
+        model = FlashLlamaForCausalLM(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model_id=model_id,
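Editor's note, not part of the patch: the removed per-layer code folded each adapter's LoRA pair into one pre-multiplied matrix (lora_scale * A^T @ B^T) and threaded batch_lora_adapter_mask / lora_indices through every forward() signature. After this change the layers only hand adapter_data to the LoRA-aware projection (for example self.query_key_value(hidden_states, adapter_data)), which owns the adapter math. The sketch below illustrates the general shape of such an adapter-aware linear layer under assumed names; SimpleAdapterData and LoraLinear are invented for illustration and are not the types used by text-generation-inference.

# Illustrative sketch only (assumed names, not TGI's API): a linear layer that
# consumes a single adapter-data object instead of separate mask/index args.
from dataclasses import dataclass
from typing import Dict, Tuple

import torch
import torch.nn as nn


@dataclass
class SimpleAdapterData:
    # Adapter index per row of the flattened batch; -1 means "no adapter".
    indices: torch.Tensor
    # adapter index -> (lora_A [r, in], lora_B [out, r], scale = lora_alpha / r)
    weights: Dict[int, Tuple[torch.Tensor, torch.Tensor, float]]


class LoraLinear(nn.Module):
    """Wraps a base linear layer and adds the per-request LoRA delta."""

    def __init__(self, base: nn.Linear):
        super().__init__()
        self.base = base

    def forward(self, x: torch.Tensor, adapter_data: SimpleAdapterData) -> torch.Tensor:
        y = self.base(x)
        delta = torch.zeros_like(y)
        for idx, (lora_a, lora_b, scale) in adapter_data.weights.items():
            rows = adapter_data.indices == idx
            if rows.any():
                # scale * (x @ A^T) @ B^T -- the same math the removed code
                # pre-folded into a single (hidden, hidden) matrix per adapter.
                delta[rows] = scale * (x[rows] @ lora_a.T) @ lora_b.T
        return y + delta


if __name__ == "__main__":
    torch.manual_seed(0)
    layer = LoraLinear(nn.Linear(16, 16, bias=False))
    r = 4
    adapter_data = SimpleAdapterData(
        indices=torch.tensor([0, -1, 0, 1]),
        weights={
            0: (torch.randn(r, 16), torch.randn(16, r), 8 / r),
            1: (torch.randn(r, 16), torch.randn(16, r), 8 / r),
        },
    )
    out = layer(torch.randn(4, 16), adapter_data)
    print(out.shape)  # torch.Size([4, 16])

Because the adapter bookkeeping lives inside the object passed as adapter_data, the callers in flash_causal_lm.py no longer need to build or forward any mask or index tensors, which is exactly what the deletions above remove.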