diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 5e739703..e00437c7 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -157,7 +157,7 @@ async fn prefill( top_n_tokens: top_n_tokens.unwrap_or(0), blocks: vec![], slots: vec![], - adapter_id: None, + adapter_index: None, }) .collect(); diff --git a/proto/generate.proto b/proto/generate.proto index cffaa719..366a5418 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -107,8 +107,8 @@ message Request { bool prefill_logprobs = 6; /// Return most likely n tokens uint32 top_n_tokens = 7; - /// LORA adapter id - optional string adapter_id = 8; + /// LORA adapter index + optional uint32 adapter_index = 8; } message Batch { diff --git a/router/client/src/v2/client.rs b/router/client/src/v2/client.rs index 3e5d9d3b..ff1a70eb 100644 --- a/router/client/src/v2/client.rs +++ b/router/client/src/v2/client.rs @@ -154,7 +154,7 @@ impl Client { }), prefill_logprobs: true, top_n_tokens: 20, - adapter_id: None, + adapter_index: None, }); n_tokens += max_input_length; diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs index 9265b79a..e284b251 100644 --- a/router/src/infer/v2/queue.rs +++ b/router/src/infer/v2/queue.rs @@ -290,7 +290,7 @@ impl State { entry.request.stopping_parameters.clone(), )), top_n_tokens: entry.request.top_n_tokens, - adapter_id: entry.request.adapter_id.clone(), + adapter_index: entry.request.adapter_index, }); // Set batch_time entry.batch_time = Some(Instant::now()); @@ -430,7 +430,7 @@ mod tests { stop_sequences: vec![], }, top_n_tokens: 0, - adapter_id: None, + adapter_index: None, }, response_tx, span: info_span!("entry"), diff --git a/router/src/lib.rs b/router/src/lib.rs index 08d57873..c99e8281 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -302,7 +302,7 @@ pub(crate) struct GenerateParameters { /// Lora adapter id #[serde(default)] #[schema(nullable = true, default = "null", example = "null")] - pub adapter_id: Option, + pub adapter_index: Option, } fn default_max_new_tokens() -> Option { @@ -329,7 +329,7 @@ fn default_parameters() -> GenerateParameters { seed: None, top_n_tokens: None, grammar: None, - adapter_id: None, + adapter_index: None, } } diff --git a/router/src/validation.rs b/router/src/validation.rs index e2bf5a5d..6f776870 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -202,7 +202,7 @@ impl Validation { decoder_input_details, top_n_tokens, grammar, - adapter_id, + adapter_index, .. } = request.parameters; @@ -384,7 +384,7 @@ impl Validation { parameters, stopping_parameters, top_n_tokens, - adapter_id, + adapter_index, }) } @@ -680,7 +680,7 @@ pub(crate) struct ValidGenerateRequest { pub parameters: ValidParameters, pub stopping_parameters: ValidStoppingParameters, pub top_n_tokens: u32, - pub adapter_id: Option, + pub adapter_index: Option, } #[derive(Error, Debug)] diff --git a/server/text_generation_server/adapters/__init__.py b/server/text_generation_server/adapters/__init__.py new file mode 100644 index 00000000..0e6f6e45 --- /dev/null +++ b/server/text_generation_server/adapters/__init__.py @@ -0,0 +1,31 @@ +import json +from pathlib import Path +from typing import Dict, Optional + +from text_generation_server.adapters.config import AdapterConfig +from text_generation_server.adapters.lora import LoraConfig +from text_generation_server.adapters.weights import ( + AdapterBatchData, + AdapterBatchMetadata, +) + + +def load_adapter_config( + config_path: Optional[Path], + adapter_config_path: Optional[Path], + api_token: str, +) -> AdapterConfig: + if adapter_config_path is not None and adapter_config_path.exists(): + return LoraConfig.load(str(adapter_config_path.parent), api_token) + + raise ValueError( + f"No valid adapter config file found: " + f"tried {adapter_config_path} and {config_path}" + ) + + +__all__ = [ + "AdapterBatchData", + "AdapterBatchMetadata", + "load_adapter_config", +] diff --git a/server/text_generation_server/adapters/config.py b/server/text_generation_server/adapters/config.py new file mode 100644 index 00000000..653c7bc8 --- /dev/null +++ b/server/text_generation_server/adapters/config.py @@ -0,0 +1,37 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple + +import torch + +from text_generation_server.adapters.weights import AdapterWeights + +if TYPE_CHECKING: + from text_generation_server.models.model import Model + + +ModuleMap = Dict[str, Dict[str, Tuple[torch.Tensor, str]]] + + +@dataclass +class AdapterConfig(ABC): + base_model_name_or_path: str + + @abstractmethod + def map_weights_for_model( + self, + adapter_weights: Dict, + weight_names: Tuple[str], + ) -> Tuple[ModuleMap, Set[str]]: + pass + + @abstractmethod + def load_batched_adapter_weights( + self, + model: "Model", + module_map: Dict[str, Dict], + layer_type: str, + unused_weight_names: Set[str], + dynamic: bool, + ) -> Optional[AdapterWeights]: + pass diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py new file mode 100644 index 00000000..458a22e1 --- /dev/null +++ b/server/text_generation_server/adapters/lora.py @@ -0,0 +1,430 @@ +from collections import defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union + +import torch +from peft import LoraConfig as _LoraConfig +from torch.distributed import ProcessGroup + +from text_generation_server.adapters.config import AdapterConfig, ModuleMap + +LORA = "lora" +from text_generation_server.adapters.weights import ( + AdapterBatchMetadata, + AdapterWeights, + BatchAdapterWeights, +) +from text_generation_server.utils.sgmv import ( + BGMV_MAX_RANK, + MAX_RANK_CUSTOM, + get_tmp_tensors, + orient_for_rank, + pad_rank, + use_cutlass_shrink, +) + +if TYPE_CHECKING: + from text_generation_server.models.model import Model + +EMPTY_TENSOR = torch.tensor([]) + + +@dataclass +class LoraConfig(AdapterConfig): + r: int + target_modules: Optional[Union[List[str], str]] + fan_in_fan_out: bool + lora_alpha: int + use_rslora: bool + + def map_weights_for_model( + self, + adapter_weights: Dict, + weight_names: Tuple[str], + ) -> Tuple[ModuleMap, Set[str]]: + adapter_weight_names = set() + module_map = {} + for weight_name in weight_names: + lora_a_name = f"base_model.model.{weight_name}.lora_A.weight" + lora_b_name = f"base_model.model.{weight_name}.lora_B.weight" + if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights: + continue + + module_map[weight_name] = { + "lora_A": (adapter_weights[lora_a_name], lora_a_name), + "lora_B": (adapter_weights[lora_b_name], lora_b_name), + } + adapter_weight_names.add(lora_a_name) + adapter_weight_names.add(lora_b_name) + return module_map, adapter_weight_names + + def load_batched_adapter_weights( + self, + model: "Model", + module_map: Dict[str, Dict], + layer_type: str, + unused_weight_names: Set[str], + dynamic: bool, + ) -> Optional[AdapterWeights]: + return LoraWeights.load( + self, + model, + module_map, + layer_type, + unused_weight_names, + ) + + @classmethod + def load(cls, adapter_id: str, api_token: str) -> "LoraConfig": + hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token) + return cls( + base_model_name_or_path=hf_config.base_model_name_or_path, + r=hf_config.r, + target_modules=hf_config.target_modules, + fan_in_fan_out=hf_config.fan_in_fan_out, + lora_alpha=hf_config.lora_alpha, + use_rslora=( + hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False + ), + ) + + +class LoraWeights(AdapterWeights): + """LoRA weights for a single adapter merged across all layers.""" + + def __init__( + self, + weights_a: List[torch.Tensor], + weights_b: List[torch.Tensor], + adapter_config: LoraConfig, + ): + self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1 + self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1 + + self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r) + self._is_transposed = False + + # [num_layers, hidden_size, r] + weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a] + self._weights_a = torch.stack(weights_a) + + # [num_layers, r, hidden_size] + self._weights_b = torch.stack(weights_b) + + self.adapter_config = adapter_config + + @property + def weights_a(self) -> torch.Tensor: + if self._is_transposed: + self._transpose_weights() + return self._weights_a + + @property + def weights_b(self) -> torch.Tensor: + if self._is_transposed: + self._transpose_weights() + return self._weights_b + + @property + def weights_a_t(self) -> torch.Tensor: + if not self._is_transposed: + self._transpose_weights() + return self._weights_a + + @property + def weights_b_t(self) -> torch.Tensor: + if not self._is_transposed: + self._transpose_weights() + return self._weights_b + + def _transpose_weights(self): + if self._use_cutlass_shrink: + # If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation + self._weights_a = self._weights_a.transpose(1, 2).contiguous() + self._weights_b = self._weights_b.transpose(1, 2).contiguous() + self._is_transposed = not self._is_transposed + + @classmethod + def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]: + return [BatchLoraWeights] + + @classmethod + def load( + cls, + config: LoraConfig, + model: "Model", + module_map: Dict[str, Dict], + layer_type: str, + unused_weight_names: Set[str], + ) -> Optional[AdapterWeights]: + nlayers = model.get_num_layers_for_type(layer_type) + lora_a_list = [None] * nlayers + lora_b_list = [None] * nlayers + + for layer_id in range(nlayers): + key = (layer_id, layer_type) + weight_name, layer = model.target_to_layer[key] + base_weight = layer.base_layer.linear.weight + base_device = base_weight.device + + if weight_name not in module_map: + # There is no LoRA weight for this layer type in the adapter + return None + + lora_a, lora_a_name = module_map[weight_name]["lora_A"] + lora_a = lora_a.to(base_device, model.dtype) + + lora_b, lora_b_name = module_map[weight_name]["lora_B"] + lora_b = lora_b.to(base_device, model.dtype) + + scale = get_scaling_factor( + config.lora_alpha, + config.r, + uses_rslora=config.use_rslora, + ) + + unused_weight_names.discard(lora_a_name) + unused_weight_names.discard(lora_b_name) + + # Merge scaling factor into lora_b due to associativity of matrix multiplication: + # (A * B) * C = A * (B * C) + lora_a_list[layer_id] = lora_a.transpose(0, 1) + lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale + + # pad lora ranks to be compatible with sgmv + lora_a_list = [ + pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list + ] + lora_b_list = [ + pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list + ] + + if lora_a_list: + # update rank if it was padded + padded_rank = lora_a_list[0].size(1) + config.r = padded_rank + + return LoraWeights( + *model.shard_lora_weights(lora_a_list, lora_b_list, layer_type), + config, + ) + + +@dataclass +class RankSegments: + rank: int + + lora_a_ptr: torch.Tensor + lora_b_ptr: torch.Tensor + + # prefill (sgmv) + tmp_shrink: torch.Tensor + tmp_expand: torch.Tensor + segment_starts: torch.Tensor + segment_ends: torch.Tensor + + # decode (bgmv) + indices: torch.Tensor + + +@dataclass +class BatchLoraWeights(BatchAdapterWeights): + lora_a: Dict[int, torch.Tensor] + lora_b: Dict[int, torch.Tensor] + adapter_index_configs: Dict[int, LoraConfig] + rank_data: Dict[int, RankSegments] + use_sgmv: bool + + def has_adapter(self, adapter_index: int) -> bool: + return adapter_index in self.adapter_index_configs + + def can_vectorize(self, pg: ProcessGroup) -> bool: + return all( + rank_data.rank // pg.size() <= MAX_RANK_CUSTOM + for rank_data in self.rank_data.values() + ) + + @classmethod + def key(cls) -> str: + return LORA + + @classmethod + def load( + self, + adapter_weights: Dict[int, AdapterWeights], + meta: AdapterBatchMetadata, + prefill: bool, + prefill_head_indices: Optional[torch.Tensor], + ) -> Optional["BatchLoraWeights"]: + adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()} + adapter_weights = { + k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights) + } + if not adapter_weights: + return None + + first_weights = list(adapter_weights.values())[0] + device = first_weights.weights_a.device + segment_indices = meta.segment_indices + + lora_a = { + idx: adapter_weights[idx].weights_a + for idx in segment_indices + if idx in adapter_weights + } + lora_b = { + idx: adapter_weights[idx].weights_b + for idx in segment_indices + if idx in adapter_weights + } + + max_rank = max( + adapter_weights[idx].lora_a_r + for idx in segment_indices + if idx in adapter_weights + ) + + if prefill or max_rank > BGMV_MAX_RANK: + use_sgmv = True + lora_a_ptr = torch.tensor( + [ + ( + adapter_weights[idx].weights_a.data_ptr() + if idx in adapter_weights + else EMPTY_TENSOR.data_ptr() + ) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + lora_b_ptr = torch.tensor( + [ + ( + adapter_weights[idx].weights_b.data_ptr() + if idx in adapter_weights + else EMPTY_TENSOR.data_ptr() + ) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + else: + use_sgmv = False + lora_a_ptr = torch.tensor( + [ + ( + adapter_weights[idx].weights_a_t.data_ptr() + if idx in adapter_weights + else EMPTY_TENSOR.data_ptr() + ) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + lora_b_ptr = torch.tensor( + [ + ( + adapter_weights[idx].weights_b_t.data_ptr() + if idx in adapter_weights + else EMPTY_TENSOR.data_ptr() + ) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + + adapter_index_configs = { + idx: adapter_weights[idx].adapter_config + for idx in segment_indices + if idx in adapter_weights + } + + adapter_to_segment = {v: k for k, v in enumerate(segment_indices)} + + rank_indices = defaultdict(list) + for segment_idx, adapter_idx in enumerate(segment_indices): + if adapter_idx not in adapter_weights: + continue + rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx) + + if prefill_head_indices is not None: + j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0] + for head_index in prefill_head_indices: + # j cannot go out of bounds as that would mean there are tokens without corresponding adapters + if head_index < meta.adapter_segments[j]: + prefill_head_segment_ends[-1] += 1 + else: + prefill_head_segment_starts.append(prefill_head_segment_ends[-1]) + prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1) + j += 1 + + rank_data = {} + for rank, indices in rank_indices.items(): + tmp_shrink = None + tmp_expand = None + segment_starts = None + segment_ends = None + batch_indices = None + + if use_sgmv: + lora_a_ptr_indices = lora_a_ptr[indices] + tmp_shrink, tmp_expand = get_tmp_tensors( + lora_a_ptr_indices.size(0), rank, device + ) + segment_starts = meta.adapter_segments[indices] + segment_ends = meta.adapter_segments[[i + 1 for i in indices]] + if prefill_head_indices is not None: + for i, segment_index in enumerate(indices): + segment_starts[i] = prefill_head_segment_starts[segment_index] + segment_ends[i] = prefill_head_segment_ends[segment_index] + else: + rank_indices = set(indices) + batch_indices = [ + adapter_to_segment[idx] for idx in meta.adapter_indices.tolist() + ] + batch_indices = [ + idx if idx in rank_indices else -1 for idx in batch_indices + ] + batch_indices = torch.tensor( + batch_indices, dtype=torch.int64, device=device + ) + + rank_data[rank] = RankSegments( + rank=rank, + tmp_shrink=tmp_shrink, + tmp_expand=tmp_expand, + lora_a_ptr=lora_a_ptr[indices], + lora_b_ptr=lora_b_ptr[indices], + segment_starts=segment_starts, + segment_ends=segment_ends, + indices=batch_indices, + ) + + return BatchLoraWeights( + lora_a=lora_a, + lora_b=lora_b, + adapter_index_configs=adapter_index_configs, + rank_data=rank_data, + use_sgmv=use_sgmv, + ) + + +def get_scaling_factor( + lora_alpha: int, + r: int, + uses_rslora: bool = False, +) -> float: + """Computes the scaling factor for the lora weights.""" + if uses_rslora: + return lora_alpha / (r**0.5) + return lora_alpha / r + + +def _convert_lora(v: AdapterWeights) -> AdapterWeights: + if hasattr(v, "lora_weights"): + return v.lora_weights + return v diff --git a/server/text_generation_server/adapters/weights.py b/server/text_generation_server/adapters/weights.py new file mode 100644 index 00000000..2ed08df5 --- /dev/null +++ b/server/text_generation_server/adapters/weights.py @@ -0,0 +1,159 @@ +############# +from abc import ABC, abstractclassmethod +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, List, Optional, Set, Type + +import torch + + +LORA = "lora" +LM_HEAD = "lm_head" + + +@dataclass +class AdapterBatchMetadata: + # [batch_size] + adapter_indices: torch.Tensor + + # [num_adapters] + adapter_set: Set[int] + + # [num_segments + 1] + adapter_segments: torch.Tensor + + # [num_segments] + # maps from segment index to adapter index, i.e.: + # segment_indices[s] == adapter_indices[i] + segment_indices: List[int] + + +class AdapterWeights(ABC): + @abstractclassmethod + def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]: + pass + + @property + def speculative_tokens(self) -> int: + return 0 + + +class BatchAdapterWeights(ABC): + @abstractclassmethod + def has_adapter(self, adapter_index: int) -> bool: + pass + + @abstractclassmethod + def key(cls) -> str: + pass + + @abstractclassmethod + def load( + cls, + adapter_weights: Dict[int, AdapterWeights], + meta: "AdapterBatchMetadata", + prefill: bool, + prefill_head_indices: torch.Tensor, + ) -> Optional["BatchAdapterWeights"]: + pass + + +class LayerAdapterWeights: + """Adapter weights that apply to a particular layer.""" + + def __init__(self): + self.adapter_weights: Dict[int, AdapterWeights] = {} + + def add_adapter(self, adapter_idx: int, weights: AdapterWeights): + self.adapter_weights[adapter_idx] = weights + + def remove_adapter(self, adapter_idx: int): + if adapter_idx not in self.adapter_weights: + return + del self.adapter_weights[adapter_idx] + + @property + def max_speculative_tokens(self) -> int: + return max( + adapter_weights.speculative_tokens + for adapter_weights in self.adapter_weights.values() + ) + + def is_empty(self) -> bool: + return len(self.adapter_weights) == 0 + + def get_data( + self, + meta: AdapterBatchMetadata, + prefill: bool, + prefill_head_indices: Optional[torch.Tensor], + ) -> Dict[str, BatchAdapterWeights]: + # bucket adapters by batch class + adapter_batch_types: Dict[ + Type[BatchAdapterWeights], Dict[int, AdapterWeights] + ] = defaultdict(dict) + for adapter_index, adapter_weights in self.adapter_weights.items(): + for batch_type in adapter_weights.get_batch_types(): + adapter_batch_types[batch_type][adapter_index] = adapter_weights + + batch_data = {} + for batch_type, adapter_weights in adapter_batch_types.items(): + batched_weights = batch_type.load( + adapter_weights, meta, prefill, prefill_head_indices + ) + if batched_weights is not None: + batch_data[batch_type.key()] = batched_weights + return batch_data + + +@dataclass +class AdapterBatchData: + meta: AdapterBatchMetadata + + # layer type -> adapter type -> batch weight data + data: Dict[str, Dict[str, BatchAdapterWeights]] + + prefill: bool + + @staticmethod + def from_meta( + meta: AdapterBatchMetadata, + weights: Dict[str, LayerAdapterWeights], + prefill: bool, + prefill_head_indices: Optional[torch.Tensor], + ) -> "AdapterBatchData": + data = {} + for k, v in weights.items(): + if v.is_empty(): + continue + data[k] = v.get_data( + meta, prefill, prefill_head_indices if k == LM_HEAD else None + ) + return AdapterBatchData(meta=meta, data=data, prefill=prefill) + + def ranks(self) -> Set[int]: + # TODO(travis): refactor to be less coupled to lora implementation + ranks = set() + for layer_data in self.data.values(): + lora_data = layer_data.get(LORA) + if lora_data is None: + continue + + for rank_data in lora_data.rank_data.values(): + ranks.add(rank_data.rank) + + return ranks + + def layer_names(self) -> Set[str]: + return set(self.data.keys()) + + def adapter_keys(self) -> Set[str]: + adapter_keys = set() + for layer_data in self.data.values(): + adapter_keys.update(layer_data.keys()) + return adapter_keys + + @property + def max_rank(self) -> int: + ranks = self.ranks() + return max(ranks) if len(ranks) > 0 else 0 diff --git a/server/text_generation_server/layers/__init__.py b/server/text_generation_server/layers/__init__.py index c29dd092..32c8d121 100644 --- a/server/text_generation_server/layers/__init__.py +++ b/server/text_generation_server/layers/__init__.py @@ -12,3 +12,9 @@ from text_generation_server.layers.speculative import SpeculativeHead # Just to add the `load` methods. from text_generation_server.layers.layernorm import load_layer_norm from text_generation_server.layers.conv import load_conv2d + +from text_generation_server.layers.lora import ( + LoraLinear, + TensorParallelMultiAdapterLinear, + TensorParallelAdapterRowLinear, +) diff --git a/server/text_generation_server/layers/lora.py b/server/text_generation_server/layers/lora.py new file mode 100644 index 00000000..30070287 --- /dev/null +++ b/server/text_generation_server/layers/lora.py @@ -0,0 +1,244 @@ +import math +import os +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.distributed +from accelerate import init_empty_weights +from torch import nn +from torch.nn import functional as F + +from text_generation_server.utils.sgmv import ( + add_lora_a_bgmv, + add_lora_b_bgmv, + has_sgmv, + lora_a_sgmv_cutlass, + lora_b_sgmv_cutlass, + orient_for_rank, +) + +LORA = "lora" +MEDUSA = "medusa" + +if TYPE_CHECKING: + from text_generation_server.adapters import AdapterBatchData + from text_generation_server.adapters.lora import BatchLoraWeights + + +class LoraLinear(nn.Module): + def __init__(self, base_layer, layer_id, process_group): + super().__init__() + self.base_layer = base_layer + self.layer_id = layer_id + self.process_group = process_group + + def forward_layer_type( + self, + result: torch.Tensor, + input: torch.Tensor, + adapter_data: "AdapterBatchData", + layer_type: str, + start_idx: int, + end_idx: int, + ) -> torch.Tensor: + data = adapter_data.data.get(layer_type) + data: Optional["BatchLoraWeights"] = ( + data.get(LORA) if data is not None else None + ) + + if has_sgmv() and data is not None and data.can_vectorize(self.process_group): + if end_idx - start_idx != result.shape[1]: + proj = torch.zeros_like(result[:, start_idx:end_idx]) + else: + proj = result + + for r, rank_segments in data.rank_data.items(): + lora_a_ptr = rank_segments.lora_a_ptr + lora_b_ptr = rank_segments.lora_b_ptr + + if data.use_sgmv: + # Use SGMV for prefill + if lora_a_ptr is not None and lora_b_ptr is not None: + v = lora_a_sgmv_cutlass( + input, + rank_segments.tmp_shrink, + lora_a_ptr, + rank_segments.segment_starts, + rank_segments.segment_ends, + self.layer_id, + r, + ) + + if self.process_group.size() > 1: + v = self.collect_lora_a(v) + + lora_b_sgmv_cutlass( + proj, + v, + rank_segments.tmp_expand, + lora_b_ptr, + rank_segments.segment_starts, + rank_segments.segment_ends, + self.layer_id, + ) + else: + # Use BGMV for decode + if lora_a_ptr is not None and lora_b_ptr is not None: + v = torch.zeros( + (input.size(0), r), dtype=input.dtype, device=input.device + ) + add_lora_a_bgmv( + v, + input, + lora_a_ptr, + rank_segments.indices, + self.layer_id, + ) + + if self.process_group.size() > 1: + v = self.collect_lora_a(v) + + add_lora_b_bgmv( + proj, + v, + lora_b_ptr, + rank_segments.indices, + self.layer_id, + ) + + if end_idx - start_idx != result.shape[1]: + result[:, start_idx:end_idx] += proj + else: + for adapter_index in adapter_data.meta.adapter_set: + if data is not None and data.has_adapter(adapter_index): + adapter_mask = ( + (adapter_data.meta.adapter_indices == adapter_index) + .to(input.dtype) + .view(-1, 1) + ) + layer_result = self.forward_lora( + input, data, adapter_index, adapter_mask + ) + result[:, start_idx:end_idx] += layer_result + + return result + + def forward_lora( + self, + input: torch.Tensor, + data: "BatchLoraWeights", + adapter_index: int, + adapter_mask: torch.Tensor, + ) -> torch.Tensor: + lora_a = data.lora_a[adapter_index][self.layer_id, :, :] + lora_b = data.lora_b[adapter_index][self.layer_id, :, :] + + lora_a = orient_for_rank(lora_a, lora_b.size(0)) + + a_out = input @ lora_a + if self.process_group.size() > 1: + a_out = self.collect_lora_a(a_out) + + result = (a_out @ lora_b) * adapter_mask + return result + + def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("Implemented in subclasses") + + +class TensorParallelMultiAdapterLinear(LoraLinear): + def __init__(self, base_layer, layer_id, layer_names, sizes, process_group): + super().__init__(base_layer, layer_id, process_group) + self.layer_names = layer_names + self.sizes = sizes + + @classmethod + def load(cls, base_layer, layer_id, layer_names, sizes, process_group): + return TensorParallelMultiAdapterLinear( + base_layer, layer_id, layer_names, sizes, process_group + ) + + def forward( + self, input: torch.Tensor, adapter_data: "AdapterBatchData" + ) -> torch.Tensor: + result = self.base_layer(input) + + # handle models like Bloom that have inputs of shape + # (batch_size, sequence_length, hidden_size) + # we need to reshape them to (batch_size * sequence_length, hidden_size) + # for the LoRA computation, then reshape back + prev_shape = result.shape + is_3d = len(input.shape) >= 3 + if is_3d: + input = input.reshape(-1, input.shape[-1]) + result = result.reshape(-1, result.shape[-1]) + + offset = 0 + for i, layer_name in enumerate(self.layer_names): + start_idx = offset // self.process_group.size() + + if self.sizes is not None: + offset += self.sizes[i] + end_idx = offset // self.process_group.size() + else: + end_idx = result.shape[1] + + result = self.forward_layer_type( + result, input, adapter_data, layer_name, start_idx, end_idx + ) + + if is_3d: + result = result.reshape(prev_shape) + + return result + + def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: + # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise. + # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks. + # + # TODO(travis): this is not very efficient as we do an all-gather for every adapter, + # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same + # rank, compute `a_out` on each, and then slice them into the buffer as shown here: + # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609 + gathered_tensors = [ + torch.empty_like(a_out) for _ in range(self.process_group.size()) + ] + torch.distributed.all_gather(gathered_tensors, a_out) + return torch.cat(gathered_tensors, dim=1) + + +class TensorParallelAdapterRowLinear(LoraLinear): + def __init__(self, base_layer, layer_id, layer_name, process_group): + super().__init__(base_layer, layer_id, process_group) + self.layer_name = layer_name + + @classmethod + def load(cls, base_layer, layer_id, layer_name, process_group): + return cls(base_layer, layer_id, layer_name, process_group) + + def forward( + self, input: torch.Tensor, adapter_data: "AdapterBatchData" + ) -> torch.Tensor: + result = self.base_layer(input) + + # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285 + stride = result.shape[-1] // self.process_group.size() + start_idx = self.process_group.rank() * stride + end_idx = (self.process_group.rank() + 1) * stride + + self.forward_layer_type( + result, input, adapter_data, self.layer_name, start_idx, end_idx + ) + + return result + + def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: + # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise. + # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks. + # + # TODO(travis): this is not very efficient as we do an all-reduce for every adapter, + # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same + # rank, compute `a_out` on each, and then slice them into the buffer as shown here: + # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609 + torch.distributed.all_reduce(a_out, group=self.process_group) + return a_out diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index fac54480..3e9cedcb 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -38,6 +38,8 @@ from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, SpeculativeHead, + TensorParallelMultiAdapterLinear, + TensorParallelAdapterRowLinear, ) from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( @@ -50,6 +52,16 @@ if SYSTEM == "rocm": except Exception as e: raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") +# Constants +Q_PROJ = "q_proj" +K_PROJ = "k_proj" +V_PROJ = "v_proj" +O_PROJ = "o_proj" + +GATE_PROJ = "gate_proj" +UP_PROJ = "up_proj" +DOWN_PROJ = "down_proj" + def load_attention(config, prefix, weights): # Only defined in granite. @@ -57,7 +69,7 @@ def load_attention(config, prefix, weights): # if specific model type, load the correct attention if config.model_type == "phi3": - return TensorParallelColumnLinear.load_qkv( + base_layer = TensorParallelColumnLinear.load_qkv( config, prefix=f"{prefix}.qkv_proj", weights=weights, @@ -66,7 +78,7 @@ def load_attention(config, prefix, weights): num_key_value_heads=config.num_key_value_heads, ) elif config.model_type == "baichuan": - return TensorParallelColumnLinear.load_qkv( + base_layer = TensorParallelColumnLinear.load_qkv( config, prefix=f"{prefix}.W_pack", weights=weights, @@ -76,7 +88,7 @@ def load_attention(config, prefix, weights): ) # otherwise, load the default attention based on the number of heads - return TensorParallelColumnLinear.load_multi( + base_layer = TensorParallelColumnLinear.load_multi( config, prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], dim=0, @@ -84,6 +96,19 @@ def load_attention(config, prefix, weights): bias=bias, ) + head_size = config.hidden_size // config.num_attention_heads + return TensorParallelMultiAdapterLinear.load( + base_layer, + layer_id, + [Q_PROJ, K_PROJ, V_PROJ], + sizes=[ + head_size * config.num_attention_heads, + head_size * config.num_key_value_heads, + head_size * config.num_key_value_heads, + ], + process_group=weights.process_group, + ) + class FlashLlamaAttention(torch.nn.Module): def __init__( @@ -124,7 +149,7 @@ class FlashLlamaAttention(torch.nn.Module): config.num_key_value_heads // weights.process_group.size() ) - self.query_key_value = load_attention(config, prefix, weights) + self.query_key_value = load_attention(config, prefix, weights, index) self.index = index self.adapter_weights = {} adapter_names = list(lora_weights.keys()) @@ -161,12 +186,20 @@ class FlashLlamaAttention(torch.nn.Module): pre_multiplied_lora_matrix ) - self.o_proj = TensorParallelRowLinear.load( + o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", weights=weights, bias=False, ) + + self.o_proj = TensorParallelAdapterRowLinear.load( + o_proj, + index, + O_PROJ, + process_group=weights.process_group, + ) + self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -185,8 +218,9 @@ class FlashLlamaAttention(torch.nn.Module): max_s, batch_lora_adapter_mask, lora_indices, + adapter_data, ): - qkv = self.query_key_value(hidden_states) + qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( [ self.head_size * self.num_heads, @@ -197,32 +231,6 @@ class FlashLlamaAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) - batch_size = query.size(0) - - # hidden states without LoRA - hs_wl = hidden_states[lora_indices == -1] - - adapted_query_states = [hs_wl] - adapted_value_states = [hs_wl] - - for ind in range(self.n_loras): - mask = lora_indices == ind - hs_sub = hidden_states[mask] - mat_q = torch.matmul(hs_sub, self.pre_multiplied_lora_matrices[ind, 0]) - mat_v = torch.matmul(hs_sub, self.pre_multiplied_lora_matrices[ind, 1]) - adapted_query_states.append(mat_q) - adapted_value_states.append(mat_v) - - query_adapted = torch.cat(adapted_query_states, dim=0).view( - batch_size, self.num_heads, self.head_size - ) - value_adapted = torch.cat(adapted_value_states, dim=0).view( - batch_size, self.num_key_value_heads, self.head_size - ) - - query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask] - kv[batch_lora_adapter_mask, 1] += value_adapted[batch_lora_adapter_mask] - self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) @@ -260,7 +268,7 @@ class FlashLlamaAttention(torch.nn.Module): class LlamaMLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, weights, index): super().__init__() self.hidden_act = config.hidden_act self.act = ( @@ -278,26 +286,46 @@ class LlamaMLP(nn.Module): # Fuse gate and up proj bias = getattr(config, "mlp_bias", False) if config.model_type == "phi3": - self.gate_up_proj = TensorParallelColumnLinear.load_gate_up( + gate_up_proj = TensorParallelColumnLinear.load_gate_up( config, prefix=f"{prefix}.gate_up_proj", weights=weights, bias=bias, ) else: - self.gate_up_proj = TensorParallelColumnLinear.load_multi( + gate_up_proj = TensorParallelColumnLinear.load_multi( config, prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], weights=weights, dim=0, bias=bias, ) - self.down_proj = TensorParallelRowLinear.load( + + self.gate_up_proj = TensorParallelMultiAdapterLinear.load( + gate_up_proj, + index, + [GATE_PROJ, UP_PROJ], + sizes=[ + config.intermediate_size, + config.intermediate_size, + ], + process_group=weights.process_group, + ) + + down_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.down_proj", weights=weights, bias=bias, ) + + self.down_proj = TensorParallelAdapterRowLinear.load( + down_proj, + index, + DOWN_PROJ, + process_group=weights.process_group, + ) + self.intermediate_size = ( config.intermediate_size // weights.process_group.size() ) @@ -337,7 +365,9 @@ class FlashLlamaLayer(nn.Module): lora_weights=lora_weights, lora_configs=lora_configs, ) - self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.mlp = LlamaMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, index=index + ) self.input_layernorm = FastRMSNorm.load( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps @@ -362,6 +392,7 @@ class FlashLlamaLayer(nn.Module): max_s, batch_lora_adapter_mask, lora_indices, + adapter_data, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -378,6 +409,7 @@ class FlashLlamaLayer(nn.Module): max_s, batch_lora_adapter_mask, lora_indices, + adapter_data, ) # faster post attention rms norm @@ -440,6 +472,7 @@ class FlashLlamaModel(torch.nn.Module): prefill_cache_indices: Optional[torch.Tensor], batch_lora_adapter_mask: Optional[List[str]], lora_indices: Optional[torch.Tensor], + adapter_data, ) -> torch.Tensor: hidden_states = inputs_embeds @@ -464,6 +497,7 @@ class FlashLlamaModel(torch.nn.Module): max_s, batch_lora_adapter_mask, lora_indices, + adapter_data, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -512,6 +546,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): lm_head_indices: Optional[torch.Tensor] = None, batch_lora_adapter_mask: Optional[List[str]] = None, lora_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( @@ -527,6 +562,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): prefill_cache_indices=prefill_cache_indices, batch_lora_adapter_mask=batch_lora_adapter_mask, lora_indices=lora_indices, + adapter_data=adapter_data, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 0131c0e0..f4534bd0 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -13,6 +13,7 @@ from opentelemetry import trace from transformers import PreTrainedTokenizerBase from typing import Iterable, Optional, Tuple, List, Type, Dict +from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.import_utils import SYSTEM @@ -31,6 +32,7 @@ from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS import text_generation_server.models.globals as tgi_globals from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION +from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments from text_generation_server.utils.import_utils import ( empty_cache, @@ -114,6 +116,9 @@ class FlashCausalLMBatch(Batch): top_n_tokens: List[int] top_n_tokens_tensor: torch.Tensor + # Adapter metadata for each request + adapter_meta: AdapterBatchMetadata + # Number of blocks in this batch num_blocks: int # Maximum number of blocks @@ -174,6 +179,9 @@ class FlashCausalLMBatch(Batch): stopping_criterias = [] top_n_tokens = [] + adapter_indices_list = [] + adapter_set = set() + # Cumulative length cumulative_length = 0 cumulative_max_length = 0 @@ -225,6 +233,9 @@ class FlashCausalLMBatch(Batch): stopping_criterias.append(stopping_criteria) top_n_tokens.append(r.top_n_tokens) + adapter_indices_list.append(torch.full((input_length,), r.adapter_index)) + adapter_set.add(r.adapter_index) + # Paged attention # Remove one as the first token des not have a past speculative_length = get_speculate() @@ -296,6 +307,10 @@ class FlashCausalLMBatch(Batch): max_length, input_length + max_new_tokens + speculative_length ) + adapter_indices = torch.cat(adapter_indices_list).to( + dtype=torch.int64, device=device + ) + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device, tokenizer ) @@ -339,6 +354,11 @@ class FlashCausalLMBatch(Batch): input_lengths, dtype=torch.int32, device=device ) + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + if all_prefill_logprobs: prefill_head_indices = None prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 @@ -393,6 +413,12 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor=top_n_tokens_tensor, num_blocks=num_blocks, max_blocks=max_blocks, + adapter_meta=AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ), speculative_ids=None, ) @@ -443,6 +469,7 @@ class FlashCausalLMBatch(Batch): stopping_criterias = [] top_n_tokens = [] + adapter_set = set() num_blocks = 0 max_blocks = 0 @@ -471,6 +498,8 @@ class FlashCausalLMBatch(Batch): top_n_tokens.append(self.top_n_tokens[idx]) + adapter_set.add(self.requests[idx].adapter_index) + remaining_tokens = ( stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) @@ -498,6 +527,7 @@ class FlashCausalLMBatch(Batch): # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] + adapter_indices = self.adapter_meta.adapter_indices[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices] block_tables_tensor = self.block_tables_tensor[indices] input_lengths_tensor = self.input_lengths_tensor[indices] @@ -513,6 +543,11 @@ class FlashCausalLMBatch(Batch): # Move to GPU now that we have the whole tensor slot_indices = slot_indices.to(device) + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + return type(self)( batch_id=self.batch_id, requests=requests, @@ -543,6 +578,12 @@ class FlashCausalLMBatch(Batch): num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, + adapter_meta=AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ), ) @classmethod @@ -596,6 +637,14 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( total_batch_size, ) + total_indices_size = sum( + b.adapter_meta.adapter_indices.shape[0] for b in batches + ) + adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( + total_indices_size + ) + adapter_set = set() + adapter_segment_builder = SegmentConcatBuilder() start_slots = [] block_tables = [] @@ -613,6 +662,7 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_batch_size = 0 cumulative_slots = 0 + cumulative_adapter_indices_size = 0 for i, batch in enumerate(batches): requests.extend(batch.requests) @@ -637,6 +687,18 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor slots[slots_start_index:slots_end_index] = batch.slots + # Copy over adapter indices + adapter_start_index = cumulative_adapter_indices_size + adapter_end_index = ( + cumulative_adapter_indices_size + + batch.adapter_meta.adapter_indices.shape[0] + ) + adapter_indices[adapter_start_index:adapter_end_index] = ( + batch.adapter_meta.adapter_indices + ) + cumulative_adapter_indices_size = adapter_end_index + adapter_set.update(batch.adapter_meta.adapter_set) + all_input_ids_tensor[ start_index:end_index, : batch.all_input_ids_tensor.shape[1] ] = batch.all_input_ids_tensor[:, :max_length] @@ -680,6 +742,8 @@ class FlashCausalLMBatch(Batch): else None ) + _adapter_segments, _adapter_segment_indices = adapter_segment_builder.build() + return cls( batch_id=batches[0].batch_id, requests=requests, @@ -719,6 +783,7 @@ class FlashCausalLMBatch(Batch): class FlashCausalLM(Model): def __init__( self, + model_id: str, model: torch.nn.Module, tokenizer: PreTrainedTokenizerBase, num_layers: int, @@ -738,6 +803,7 @@ class FlashCausalLM(Model): self.kv_cache = [] super(FlashCausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=False, @@ -996,7 +1062,7 @@ class FlashCausalLM(Model): ) def forward( - self, batch: FlashCausalLMBatch + self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Model Forward if batch.speculative_ids is not None: @@ -1066,13 +1132,6 @@ class FlashCausalLM(Model): batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device) lora_indices = torch.full((bs,), -1, dtype=torch.int32, device=self.device) - for i, r in enumerate(batch.requests): - if r.adapter_id: - lora_index = self.model.get_lora_index(r.adapter_id) - input_length = batch.input_lengths[i] - lora_indices[i : i + input_length] = lora_index - batch_lora_adapter_mask[i] = True - if cu_seqlen_prefill is not None or cuda_graph is None: logits, speculative_logits = self.model.forward( input_ids=input_ids, @@ -1087,6 +1146,7 @@ class FlashCausalLM(Model): lm_head_indices=lm_head_indices, batch_lora_adapter_mask=batch_lora_adapter_mask, lora_indices=lora_indices, + adapter_data=adapter_data, ) if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None @@ -1123,7 +1183,34 @@ class FlashCausalLM(Model): prefill = batch.cu_seqlen_prefill is not None prefill_logprobs = batch.prefill_next_token_indices is not None - out, speculative_logits = self.forward(batch) + # Update adapter indices for speculative tokens (if present) + adapter_meta = batch.adapter_meta + if batch.speculative_ids is not None: + B, speculative_length = batch.speculative_ids.shape + new_length = speculative_length + 1 + adapter_indices = ( + adapter_meta.adapter_indices.unsqueeze(-1) + .expand(B, new_length) + .reshape(-1) + ) + adapter_segments = adapter_meta.adapter_segments * new_length + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_meta.adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_meta.segment_indices, + ) + + # Assign pointers to adapter weights + # TODO(travis): don't update this if indices haven't changed + adapter_data = AdapterBatchData.from_meta( + adapter_meta, + self.layer_to_adapter_weights, + prefill, + batch.prefill_head_indices, + ) + + out, speculative_logits = self.forward(batch, adapter_data) if prefill: next_token_logits = ( diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index c5d3ecac..fbe953e9 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -4,7 +4,7 @@ import torch.distributed from opentelemetry import trace from transformers import AutoConfig, AutoTokenizer, GenerationConfig -from typing import Optional +from typing import Optional, Tuple, Dict, List from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_llama_modeling import ( @@ -22,6 +22,30 @@ tracer = trace.get_tracer(__name__) from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.lora import LoraConfig +Q_PROJ = "q_proj" +K_PROJ = "k_proj" +V_PROJ = "v_proj" +O_PROJ = "o_proj" + +GATE_PROJ = "gate_proj" +UP_PROJ = "up_proj" +DOWN_PROJ = "down_proj" + +LM_HEAD = "lm_head" + + +# TODO(travis): re-enable LM_HEAD after resolving issues with outputs +ADAPTER_LAYERS = [ + Q_PROJ, + K_PROJ, + V_PROJ, + O_PROJ, + GATE_PROJ, + UP_PROJ, + DOWN_PROJ, +] # LM_HEAD +ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD} + class FlashLlama(FlashCausalLM): def __init__( @@ -80,6 +104,7 @@ class FlashLlama(FlashCausalLM): ) torch.distributed.barrier(group=self.process_group) super(FlashLlama, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), @@ -90,3 +115,59 @@ class FlashLlama(FlashCausalLM): rank=rank, world_size=world_size, ) + + @property + def supports_adapter_loading(self) -> bool: + return True + + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: + layer_weights = {} + + prefix = "model.layers" + for i, layer in enumerate(self.model.model.layers): + layer_weights[(i, Q_PROJ)] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, K_PROJ)] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, V_PROJ)] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, O_PROJ)] = ( + f"{prefix}.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) + + layer_weights[(i, GATE_PROJ)] = ( + f"{prefix}.{i}.mlp.gate_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, UP_PROJ)] = ( + f"{prefix}.{i}.mlp.up_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, DOWN_PROJ)] = ( + f"{prefix}.{i}.mlp.down_proj", + layer.mlp.down_proj, + ) + + layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head) + return layer_weights + + @property + def adapter_layers(self) -> List[str]: + return ADAPTER_LAYERS + + @property + def default_traced_adapter_layers(self) -> List[str]: + return [Q_PROJ, V_PROJ] + + def get_num_layers_for_type(self, layer_type: str) -> int: + return 1 if layer_type == LM_HEAD else len(self.model.model.layers) + + def is_row_parallel(self, layer_type: str) -> bool: + return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 4f35b0aa..a5ef7908 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -2,12 +2,50 @@ import inspect import torch from abc import ABC, abstractmethod -from typing import List, Tuple, Optional, TypeVar, Type +from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict +from collections import defaultdict from transformers import PreTrainedTokenizerBase, PretrainedConfig from text_generation_server.models.types import Batch, Generation from text_generation_server.utils.speculate import get_speculate from text_generation_server.pb.generate_pb2 import InfoResponse +from text_generation_server.adapters.weights import LayerAdapterWeights +from text_generation_server.utils.adapter import ( + load_and_merge_adapters, + AdapterParameters, + AdapterSource, +) +from loguru import logger + + +BASE_MODEL_ADAPTER_ID = "__base_model__" + + +def get_start_stop_idxs_for_rank(offset, size, rank, world_size): + block_size = size // world_size + start = offset + rank * block_size + stop = offset + (rank + 1) * block_size + return start, stop + + +def shard_on_dim( + t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup +): + world_size = process_group.size() + rank = process_group.rank() + + size = t.shape[dim] + start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size) + + if dim == 0: + tensor = t[start:stop] + elif dim == 1: + tensor = t[:, start:stop] + else: + raise NotImplementedError("Let's make that generic when needed") + + return tensor + B = TypeVar("B", bound=Batch) @@ -15,6 +53,7 @@ B = TypeVar("B", bound=Batch) class Model(ABC): def __init__( self, + model_id: str, model: torch.nn.Module, tokenizer: PreTrainedTokenizerBase, requires_padding: bool, @@ -25,6 +64,7 @@ class Model(ABC): sliding_window: Optional[int] = None, speculate: Optional[int] = None, ): + self.model_id = model_id self.model = model.eval() self.tokenizer = tokenizer @@ -42,6 +82,12 @@ class Model(ABC): self.world_size = world_size self.sliding_window = sliding_window if sliding_window != -1 else None + self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict( + LayerAdapterWeights + ) + self.target_to_layer = self.adapter_target_to_layer() + self.loaded_adapters = set() + if speculate is None: speculate = get_speculate() self.speculate = speculate @@ -119,3 +165,156 @@ class Model(ABC): raise RuntimeError( f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}" ) + + @property + def supports_adapter_loading(self) -> bool: + return False + + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: + return {} + + @property + def adapter_layers(self) -> List[str]: + return [] + + @property + def default_traced_adapter_layers(self) -> List[str]: + return [] + + def get_num_layers_for_type(self, layer_type: str) -> int: + return 0 + + def is_row_parallel(self, layer_type: str) -> bool: + return False + + @property + def max_speculative_tokens(self) -> int: + return max( + [ + weights.max_speculative_tokens + for weights in self.layer_to_adapter_weights.values() + ], + default=0, + ) + + def load_adapter( + self, + adapter_parameters: AdapterParameters, + adapter_source: AdapterSource, + adapter_index: int, + api_token: str, + dynamic: bool = True, + ): + """Loads adapter weights from disk / host memory on the GPU. + + adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded + into model. Otherwise, the adapter weights are applied during the forward + pass and stored separately from the base model parameters. + """ + if adapter_index in self.loaded_adapters: + # Adapter already loaded + return + + if not self.supports_adapter_loading: + raise ValueError("This model does not support adapter loading.") + + if dynamic and not self.dynamic_adapter_loading_enabled: + raise ValueError( + f"This model was initialized with the adapter {self.static_adapter_id} " + f"and therefore does not support dynamic adapter loading. " + f"Please initialize a new model instance from the base model in " + f"order to use the dynamic adapter loading feature." + ) + + logger.info( + f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}" + ) + weight_names = tuple([v[0] for v in self.target_to_layer.values()]) + ( + module_map, + adapter_config, + adapter_weight_names, + adapter_tokenizer, + ) = load_and_merge_adapters( + self.model_id, + adapter_parameters, + adapter_source, + adapter_index, + weight_names, + api_token, + False, + ) + + unused_weight_names = adapter_weight_names.copy() + for layer_name in self.adapter_layers: + adapter_weights = adapter_config.load_batched_adapter_weights( + self, + module_map, + layer_name, + unused_weight_names, + dynamic, + ) + + if adapter_weights is None: + continue + + layer_weights = self.layer_to_adapter_weights[layer_name] + layer_weights.add_adapter(adapter_index, adapter_weights) + + if len(unused_weight_names) > 0: + logger.warning( + f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}" + ) + + if adapter_tokenizer is not None: + self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) + + self.loaded_adapters.add(adapter_index) + + def shard_lora_weights( + self, + weights_a: List[torch.Tensor], + weights_b: List[torch.Tensor], + layer_type: str, + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + # [hidden_size, r] + split_dim = 0 if self.is_row_parallel(layer_type) else 1 + weights_a = [ + shard_on_dim(w, dim=split_dim, process_group=self.process_group) + for w in weights_a + ] + + # [r, hidden_size] + weights_b = [ + shard_on_dim(w, dim=1, process_group=self.process_group) for w in weights_b + ] + + return weights_a, weights_b + + def offload_adapter( + self, + adapter_parameters: AdapterParameters, + adapter_source: AdapterSource, + adapter_index: int, + ): + """Offloads the adapter weights from GPU to CPU or disk.""" + if adapter_index not in self.loaded_adapters: + # Adapter already offloaded + return + + if not self.supports_adapter_loading: + raise ValueError("This model does not support adapter loading.") + + if not self.dynamic_adapter_loading_enabled: + raise ValueError( + f"This model was initialized with the adapter {self.static_adapter_id} " + f"and therefore does not support dynamic adapter loading. " + f"Please initialize a new model instance from the base model in " + f"order to use the dynamic adapter loading feature." + ) + + for layer_name in self.adapter_layers: + if layer_name in self.layer_to_adapter_weights: + self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index) + + self.loaded_adapters.remove(adapter_index) diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 9a5e9226..4a059776 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -30,6 +30,9 @@ except (ImportError, NotImplementedError): from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor from text_generation_server.models.globals import set_model_id +from text_generation_server.utils.adapter import ( + AdapterParameters, +) class SignalHandler: @@ -235,6 +238,30 @@ def serve( trust_remote_code, max_input_tokens, ) + + # TODO: avoid hacky hardcoded adapter id + adapter_parameters = AdapterParameters( + adapter_ids=lora_adapter_ids, + weights=[ + # TODO: fill with actual weights + torch.tensor([1.0], dtype=torch.float32) + ], + merge_strategy=0, + density=0.0, + majority_sign_method=0, + ) + adapter_source = None + adapter_index = None + api_token = None + + model.load_adapter( + adapter_parameters, + adapter_source, + adapter_index, + api_token, + False, + ) + except Exception: logger.exception("Error when initializing model") raise diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py new file mode 100644 index 00000000..53e92180 --- /dev/null +++ b/server/text_generation_server/utils/adapter.py @@ -0,0 +1,196 @@ +import warnings +from dataclasses import dataclass +from functools import lru_cache +from typing import TYPE_CHECKING, Set, Tuple + +from safetensors.torch import load_file +from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer + +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils.merges.strategies import merge_adapters + +from text_generation_server.utils import hub +from text_generation_server.adapters.lora import LoraConfig + + +if TYPE_CHECKING: + from text_generation_server.adapters.config import AdapterConfig, ModuleMap + + +BASE_MODEL_ADAPTER_ID = "__base_model__" + + +class AdapterParameters: + def __init__( + self, adapter_ids, weights, merge_strategy, density, majority_sign_method + ): + self.adapter_ids = adapter_ids + self.weights = weights + self.merge_strategy = merge_strategy + self.density = density + self.majority_sign_method = majority_sign_method + + +class AdapterSource: + def __init__(self, adapter_id: str, model_id: str, revision: str): + self.adapter_id = adapter_id + self.model_id = model_id + self.revision = revision + + +def load_and_merge_adapters( + model_id: str, + adapter_parameters: AdapterParameters, + adapter_source: str, + adapter_index: int, + weight_names: Tuple[str], + api_token: str, + trust_remote_code: bool = False, +) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: + if len(adapter_parameters.adapter_ids) == 1: + return load_module_map( + model_id, + adapter_parameters.adapter_ids[0], + adapter_source, + weight_names, + api_token, + trust_remote_code, + ) + + adapter_params = AdapterParametersContainer( + adapter_parameters, adapter_source, adapter_index + ) + return _load_and_merge( + model_id, adapter_params, weight_names, api_token, trust_remote_code + ) + + +class AdapterParametersContainer: + def __init__(self, adapter_parameters, adapter_source, adapter_index): + self.adapter_parameters = adapter_parameters + self.adapter_source = adapter_source + self.adapter_index = adapter_index + + def __hash__(self) -> int: + return self.adapter_index + + +@lru_cache(maxsize=32) +def _load_and_merge( + model_id: str, + adapter_params: AdapterParametersContainer, + weight_names: Tuple[str], + api_token: str, + trust_remote_code: bool = False, +) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: + params = adapter_params.adapter_parameters + + adapters_to_merge = [] + merged_weight_names = set() + tokenizer = None + for adapter_id in params.adapter_ids: + if adapter_id == BASE_MODEL_ADAPTER_ID: + raise ValueError("Base model adapter cannot be merged.") + + module_map, adapter_config, adapter_weight_names, adapter_tokenizer = ( + load_module_map( + model_id, + adapter_id, + adapter_params.adapter_source, + weight_names, + api_token, + trust_remote_code, + ) + ) + + adapters_to_merge.append((module_map, adapter_config)) + merged_weight_names = merged_weight_names.union(adapter_weight_names) + if tokenizer is None: + tokenizer = adapter_tokenizer + + if len(adapters_to_merge) == 0: + raise ValueError("No adapters to merge.") + + module_map, adapter_config = merge_adapters(adapters_to_merge, params) + return module_map, adapter_config, merged_weight_names, tokenizer + + +def check_architectures( + model_id: str, + adapter_id: str, + adapter_config: "AdapterConfig", + trust_remote_code: bool = False, +): + try: + if not adapter_config.base_model_name_or_path: + # Avoid execuation latency caused by the network connection retrying for AutoConfig.from_pretrained(None) + return + + expected_config = AutoConfig.from_pretrained( + model_id, trust_remote_code=trust_remote_code + ) + model_config = AutoConfig.from_pretrained( + adapter_config.base_model_name_or_path, trust_remote_code=trust_remote_code + ) + except Exception as e: + warnings.warn( + f"Unable to check architecture compatibility for adapter '{adapter_id}' " + f"against model '{model_id}'. Assuming they are compatible. Error: {e}" + ) + return + + if model_config.architectures == expected_config.architectures: + warnings.warn( + f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. " + f"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead." + ) + else: + # TODO(travis): revisit this when we support clasification heads which will not use CausalLM + raise ValueError( + f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. " + f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. " + f"Use --model-id '{adapter_config.base_model_name_or_path}' instead." + ) + + +@lru_cache(maxsize=128) +def load_module_map( + model_id: str, + adapter_id: str, + adapter_source: str, + weight_names: Tuple[str], + api_token: str, + trust_remote_code: bool = False, +) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: + print("adapter_id", adapter_id) + + revision = "main" + + adapter_config = LoraConfig.load(adapter_id, api_token) + if adapter_config.base_model_name_or_path != model_id: + check_architectures(model_id, adapter_id, adapter_config, trust_remote_code) + + adapter_filenames = hub._cached_adapter_weight_files( + adapter_id, revision=revision, extension=".safetensors" + ) + + try: + adapter_tokenizer = AutoTokenizer.from_pretrained( + adapter_config.config_path, + token=api_token, + trust_remote_code=trust_remote_code, + ) + except Exception: + # Adapter does not have a tokenizer, so fallback to base model tokenizer + adapter_tokenizer = None + + # load adapter weights from all shards (should have relatively small memory footprint) + adapter_weights = {} + for filename in adapter_filenames: + adapter_weights.update(load_file(filename)) + + # map the model weights to the relevant adapter weights (LoRA A and B matrices) + module_map, adapter_weight_names = adapter_config.map_weights_for_model( + adapter_weights, weight_names + ) + return module_map, adapter_config, adapter_weight_names, adapter_tokenizer diff --git a/server/text_generation_server/utils/lora.py b/server/text_generation_server/utils/lora.py index 64a6724a..8eed3a97 100644 --- a/server/text_generation_server/utils/lora.py +++ b/server/text_generation_server/utils/lora.py @@ -32,6 +32,7 @@ class LoraConfig: task_type="CAUSAL_LM", use_dora=False, use_rslora=False, + config_path=None, ): self.alpha_pattern = alpha_pattern or {} self.auto_mapping = auto_mapping @@ -57,12 +58,13 @@ class LoraConfig: self.task_type = task_type self.use_dora = use_dora self.use_rslora = use_rslora + self.config_path = config_path @classmethod def from_file(cls, filename): with open(filename, "r") as f: json_data = json.load(f) - return cls(**json_data) + return cls(**json_data, config_path=filename) # TODO: support fetching the model from the hub if it's not in the cache @classmethod diff --git a/server/text_generation_server/utils/merges/strategies.py b/server/text_generation_server/utils/merges/strategies.py new file mode 100644 index 00000000..3b885313 --- /dev/null +++ b/server/text_generation_server/utils/merges/strategies.py @@ -0,0 +1,223 @@ +import copy +from abc import ABC +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union + +import torch + + +class AdapterParameters: + def __init__( + self, adapter_ids, weights, merge_strategy, density, majority_sign_method + ): + self.adapter_ids = adapter_ids + self.weights = weights + self.merge_strategy = merge_strategy + self.density = density + self.majority_sign_method = majority_sign_method + + +from text_generation_server.utils.merges.utils import ( + calculate_majority_sign_mask, + disjoint_merge, + prune, +) + +if TYPE_CHECKING: + from text_generation_server.adapters.lora import LoraConfig + from text_generation_server.utils.adapter import ModuleMap + + +def _apply_weights( + tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor +) -> torch.Tensor: + if isinstance(tensors, torch.Tensor): + t = tensors + else: + t = torch.stack(tensors, dim=0) + + # element-wise weighting of each task tensor + # need to unsqueeze weights to match task tensor dimensions + # for multiplication to apply element-wise + while len(t.shape) > len(w.shape): + w = w.unsqueeze(-1) + return t * w + + +class MergeStrategy(ABC): + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + raise NotImplementedError() + + +class LinearMerge(MergeStrategy): + def __init__(self, **kwargs): + pass + + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + weighted_task_tensors = _apply_weights(task_tensors, weights) + return weighted_task_tensors.sum(dim=0) + + +class TiesMerge(MergeStrategy): + def __init__(self, density: float, majority_sign_method: str = "total", **kwargs): + self.density = density + self.majority_sign_method = majority_sign_method + + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + # sparsify + task_tensors = [ + prune(tensor, self.density, method="magnitude") for tensor in task_tensors + ] + task_tensors = torch.stack(task_tensors, dim=0) + + # elect sign before applying weights + majority_sign_mask = calculate_majority_sign_mask( + task_tensors, method=self.majority_sign_method + ) + weighted_task_tensors = _apply_weights(task_tensors, weights) + + # disjoint merge + return disjoint_merge(weighted_task_tensors, majority_sign_mask) + + +class DareLinearMerge(MergeStrategy): + def __init__(self, density: float, **kwargs): + self.density = density + + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + # sparsify + task_tensors = [ + prune(tensor, self.density, method="random", rescale=True) + for tensor in task_tensors + ] + weighted_task_tensors = _apply_weights(task_tensors, weights) + return weighted_task_tensors.sum(dim=0) + + +class DareTiesMerge(MergeStrategy): + def __init__(self, density: float, majority_sign_method: str = "total", **kwargs): + self.density = density + self.majority_sign_method = majority_sign_method + + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + # sparsify + task_tensors = [ + prune(tensor, self.density, method="random", rescale=True) + for tensor in task_tensors + ] + task_tensors = torch.stack(task_tensors, dim=0) + + # elect sign before applying weights + majority_sign_mask = calculate_majority_sign_mask( + task_tensors, method=self.majority_sign_method + ) + weighted_task_tensors = _apply_weights(task_tensors, weights) + + # disjoint merge + mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask) + return mixed_task_tensors + + +strategy_registry: Dict[str, Type[MergeStrategy]] = { + "linear": LinearMerge, + "ties": TiesMerge, + "dare_linear": DareLinearMerge, + "dare_ties": DareTiesMerge, +} + + +def merge_adapters( + adapters: List[Tuple["ModuleMap", "LoraConfig"]], + merge_params: AdapterParameters, +) -> Tuple["ModuleMap", "LoraConfig"]: + # strategy_name = MergeStrategyEnum.Name(merge_params.merge_strategy).lower() + strategy_name = "linear" + + weights = merge_params.weights + if not weights: + weights = torch.ones(len(adapters)) + else: + weights = torch.tensor(weights) + + merge_config = { + "density": merge_params.density, + # "majority_sign_method": MajoritySignMethodEnum.Name( + # merge_params.majority_sign_method + # ).lower(), + "majority_sign_method": "total", + } + merge_strategy = strategy_registry[strategy_name](**merge_config) + + module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict( + lambda: defaultdict(lambda: defaultdict(list)) + ) + lora_configs = [] + weight_name_to_adapter_idx = defaultdict(list) + + # input is list of (module_map, lora_config) tuples + # convert into dict[k][param_name] -> list of tensors + for idx, (module_map, lora_config) in enumerate(adapters): + for weight_name, data in module_map.items(): + weight_name_to_adapter_idx[weight_name].append(idx) + for k, (param_data, param_name) in data.items(): + module_maps[weight_name][k][param_name].append(param_data) + lora_configs.append(lora_config) + + # validate lora configs are compatible + _validate_lora_configs(lora_configs) + + # merge tensors for each module such that we have a single ModuleMap: + # dict[k] -> merged tensor + merged_module_map: "ModuleMap" = defaultdict(dict) + for weight_name, data in module_maps.items(): + indices = weight_name_to_adapter_idx[weight_name] + param_weights = weights[indices] + for k, param_data in data.items(): + for param_name, tensors in param_data.items(): + merged_tensor = merge_strategy.merge(tensors, param_weights) + merged_module_map[weight_name][k] = (merged_tensor, param_name) + + # merge lora configs + merged_lora_config = _merge_lora_configs(lora_configs) + + return merged_module_map, merged_lora_config + + +def _validate_lora_configs(lora_configs: List["LoraConfig"]): + # check that all configs have the same rank + ranks = set(lora_config.r for lora_config in lora_configs) + if len(ranks) > 1: + raise ValueError( + f"unable to merge adapters, lora configs have different ranks: {ranks}" + ) + + if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs): + raise ValueError( + "unable to merge adapters, lora configs have no target modules" + ) + + +def _merge_lora_configs(lora_configs: List["LoraConfig"]) -> "LoraConfig": + merged_lora_config = copy.copy(lora_configs[0]) + + # merge target modules as a union operation + merged_target_modules = sorted( + set( + module + for lora_config in lora_configs + for module in lora_config.target_modules + ) + ) + merged_lora_config.target_modules = merged_target_modules + + return merged_lora_config diff --git a/server/text_generation_server/utils/merges/utils.py b/server/text_generation_server/utils/merges/utils.py new file mode 100644 index 00000000..d9ad3278 --- /dev/null +++ b/server/text_generation_server/utils/merges/utils.py @@ -0,0 +1,108 @@ +# coding=utf-8 +# From: https://github.com/huggingface/peft/pull/1364 +# Copyright 2024-present the HuggingFace Inc. team. +# Modifications by Predibase, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Literal + +import torch + + +def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor: + """ + Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction + `density`. + + Args: + tensor (`torch.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + """ + mask = torch.zeros_like(tensor).reshape(-1) + k = int(density * tensor.reshape(-1).shape[0]) + top_k = torch.topk(tensor.abs().reshape(-1), k=k, largest=True) + mask[top_k[1]] = 1 + return tensor * mask.reshape(tensor.shape) + + +def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: + """ + Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction + `density`. + + Args: + tensor (`torch.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor. + """ + mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density)) + pruned_tensor = tensor * mask + if rescale: + torch.div(input=pruned_tensor, other=density) + return pruned_tensor + + +def prune( + tensor: torch.Tensor, + density: float, + method: Literal["magnitude", "random"], + rescale: bool = False, +) -> torch.Tensor: + """ + Prune the values of task tensors based on the `method`. + + Args: + tensor (`torch.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + method (`str`):The method to use to prune. Should be one of ["magnitude", "random"]. + rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor. + """ + if density >= 1: + return tensor + elif density < 0: + raise ValueError("Density should be >= 0, got {density}") + if method == "magnitude": + return magnitude_based_pruning(tensor, density) + elif method == "random": + return random_pruning(tensor, density, rescale=rescale) + else: + raise ValueError(f"Unknown method {method}") + + +def calculate_majority_sign_mask( + tensor: torch.Tensor, method: Literal["total", "frequency"] = "total" +): + """ + Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0. + + Args: + tensor (`torch.Tensor`):The tensor to get the mask from. + method (`str`):The method to use to get the mask. Should be one of ["total", "frequency"]. + """ + + sign = tensor.sign() + if method == "total": + sign_magnitude = (sign * tensor.abs()).sum(dim=0) + elif method == "frequency": + sign_magnitude = sign.sum(dim=0) + else: + raise RuntimeError(f'Unimplemented mask method "{method}"') + majority_sign = torch.where(sign_magnitude >= 0, 1, -1) + return sign == majority_sign + + +def disjoint_merge(task_tensors, majority_sign_mask): + mixed_task_tensors = (task_tensors * majority_sign_mask).sum(dim=0) + num_params_preserved = majority_sign_mask.sum(dim=0) + return mixed_task_tensors / torch.clamp(num_params_preserved, min=1.0) diff --git a/server/text_generation_server/utils/segments.py b/server/text_generation_server/utils/segments.py new file mode 100644 index 00000000..0a50f20f --- /dev/null +++ b/server/text_generation_server/utils/segments.py @@ -0,0 +1,62 @@ +from typing import List, Tuple, Union + +import torch + + +def find_segments( + adapter_indices: Union[torch.Tensor, List[int]] +) -> Tuple[List[int], List[int]]: + segments = [0] + segment_indices = [] + + if isinstance(adapter_indices, torch.Tensor): + # Calling .item() repeatedly on CUDA tensor is very slow, so we move it to CPU first + adapter_indices = adapter_indices.cpu().tolist() + + start_index = 0 + for i in range(1, len(adapter_indices)): + if adapter_indices[i] != adapter_indices[i - 1]: + segments.append(i) + segment_indices.append(adapter_indices[i - 1]) + start_index = i + + # Handle the last segment + if start_index < len(adapter_indices): + segments.append(len(adapter_indices)) + segment_indices.append(adapter_indices[-1]) + + return segments, segment_indices + + +class SegmentConcatBuilder: + def __init__(self): + self.adapter_segment_indices = [] + self.adapter_segment_tensors = [] + + def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]): + # Update adapter segments + if self.adapter_segment_tensors: + # Because we have already processed at least one batch, remove the 0 start index + # from this batch denoting the beginning of the segment, then offset all segment + # positions by the value of the last segment in the previous batch to account for + # the concatenation. + adapter_segments = ( + adapter_segments[1:] + self.adapter_segment_tensors[-1][-1] + ) + + if ( + self.adapter_segment_indices + and self.adapter_segment_indices[-1] == segment_indices[0] + ): + # If the last segment in the previous batch is the same as the first segment in this batch, + # then we merge them together into a single segment. In effect, this means removing it from + # the segment indices of this batch, and extending the segment span by removing the segment + # end index from the previous batch. + segment_indices = segment_indices[1:] + self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1] + + self.adapter_segment_indices.extend(segment_indices) + self.adapter_segment_tensors.append(adapter_segments) + + def build(self) -> Tuple[torch.Tensor, List[int]]: + return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices diff --git a/server/text_generation_server/utils/sgmv.py b/server/text_generation_server/utils/sgmv.py new file mode 100644 index 00000000..f6551e0f --- /dev/null +++ b/server/text_generation_server/utils/sgmv.py @@ -0,0 +1,242 @@ +import os +import warnings +from functools import lru_cache +from typing import List, Tuple + +import torch +import torch.nn.functional as F + +try: + # TODO: add build steps for Punica kernels + # import punica_kernels as _kernels + import punica.ops as _kernels + + HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", "")) +except ImportError: + warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.") + _kernels = None + HAS_SGMV = False + + +MIN_SGMV_RANK = 8 +MIN_RANK_CUSTOM = 16 +MAX_RANK_CUSTOM = 128 +SGMV_BLOCK_SIZE = 16 +BGMV_MAX_RANK = 64 + + +def has_sgmv() -> bool: + return HAS_SGMV + + +def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor: + """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size.""" + if not has_sgmv(): + return t + + # tensor parallelism will result in effective rank being divided by world_size, + # so we need to scale the min rank to offset that effect + min_rank = MIN_SGMV_RANK * world_size + + # if we're at or below the min rank, pad up to the min rank + # otherwise, pad to the nearest multiple of the block size + current_rank = t.size(dim) + target_rank = ( + min_rank + if current_rank <= min_rank + else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE + ) + if current_rank == target_rank: + return t + + pad_size = target_rank - current_rank + + # see complicatd pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + pad = [0, 0] * t.dim() + pad[(t.dim() - dim - 1) * 2 + 1] = pad_size + pad = tuple(pad) + + return F.pad(t, pad, mode="constant", value=0.0) + + +def use_cutlass_shrink(lora_rank: int) -> bool: + return lora_rank < MIN_RANK_CUSTOM + + +def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor: + if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM: + return t.transpose(0, 1) + return t + + +# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py +def add_lora_sgmv_cutlass( + y: torch.Tensor, + x: torch.Tensor, + wa_ptr: torch.Tensor, + wb_ptr: torch.Tensor, + s_start: torch.Tensor, + s_end: torch.Tensor, + layer_idx: int, + lora_rank: int, +): + """ + Semantics: + y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i]) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ + Weight matrix shape: `[num_layers, R, H1]`. + wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ + Weight matrix shape: `[num_layers, R, H2]`. + s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices. + s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices. + layer_idx: Layer index of the weight matrices. + """ + if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM: + # Custom SGMV shrink only supports rank 16, 32, 64, 128 + _add_lora_sgmv_cutlass_legacy( + y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank + ) + return + + tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device) + tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) + tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device) + v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) + _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx) + _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx) + + +def _add_lora_sgmv_cutlass_legacy( + y: torch.Tensor, + x: torch.Tensor, + wa_ptr: torch.Tensor, + wb_ptr: torch.Tensor, + s_start: torch.IntTensor, + s_end: torch.IntTensor, + layer_idx: int, + lora_rank: int, +): + tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) + tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device) + v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) + _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) + _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) + + +@lru_cache(maxsize=1) +def get_tmp_tensor(device: torch.device) -> torch.Tensor: + return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device) + + +@lru_cache(maxsize=32) +def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor: + tmp_size = _kernels.sgmv_cutlass_tmp_size(size) + return torch.empty((tmp_size,), dtype=torch.uint8, device=device) + + +def get_tmp_expand_size(size: int) -> int: + return _kernels.sgmv_cutlass_tmp_size(size) + + +def get_tmp_tensors( + nsegments: int, lora_rank: int, device: torch.device +) -> Tuple[torch.Tensor, torch.Tensor]: + if use_cutlass_shrink(lora_rank): + tmp = get_tmp_tensor_for_size(nsegments, device) + return tmp, tmp + else: + tmp_shrink = get_tmp_tensor(device) + tmp_expand = get_tmp_tensor_for_size(nsegments, device) + return tmp_shrink, tmp_expand + + +def lora_a_sgmv_cutlass( + x: torch.Tensor, + tmp: torch.Tensor, + wa_ptr: torch.Tensor, + s_start: torch.IntTensor, + s_end: torch.IntTensor, + layer_idx: int, + lora_rank: int, +) -> torch.Tensor: + v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) + if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM: + _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) + else: + _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) + return v + + +def lora_b_sgmv_cutlass( + y: torch.Tensor, + v: torch.Tensor, + tmp: torch.Tensor, + wb_ptr: torch.Tensor, + s_start: torch.IntTensor, + s_end: torch.IntTensor, + layer_idx: int, +): + _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) + + +""" +Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + +Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + v: Shape: `[B, R]`. Temporary vector. + x: Shape: `[B, H1]`. Input vectors. + wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices. + wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. +""" + + +def add_lora_a_bgmv( + v: torch.Tensor, + x: torch.Tensor, + wa_T_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, +): + _kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0) + + +def add_lora_b_bgmv( + y: torch.Tensor, + v: torch.Tensor, + wb_T_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, +): + _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0) + + +def segmented_matmul( + y: torch.Tensor, + x: torch.Tensor, + w: List[torch.Tensor], + b: List[torch.Tensor], + s_start: torch.IntTensor, + s_end: torch.IntTensor, +): + for i in range(len(w)): + if s_end[i] - s_start[i] <= 0: + continue + + xi = x[s_start[i] : s_end[i]] + wi = w[i] + bi = b[i] + y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi)