feat: prefer lorax implementation and port loading logic

drbh 2024-06-05 23:56:04 +00:00
parent c661631225
commit 8b50f4b779
23 changed files with 2230 additions and 60 deletions

View File

@ -157,7 +157,7 @@ async fn prefill(
top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![],
slots: vec![],
adapter_id: None,
adapter_index: None,
})
.collect();

View File

@ -107,8 +107,8 @@ message Request {
bool prefill_logprobs = 6;
/// Return most likely n tokens
uint32 top_n_tokens = 7;
/// LORA adapter id
optional string adapter_id = 8;
/// LORA adapter index
optional uint32 adapter_index = 8;
}
message Batch {

View File

@ -154,7 +154,7 @@ impl Client {
}),
prefill_logprobs: true,
top_n_tokens: 20,
adapter_id: None,
adapter_index: None,
});
n_tokens += max_input_length;

View File

@ -290,7 +290,7 @@ impl State {
entry.request.stopping_parameters.clone(),
)),
top_n_tokens: entry.request.top_n_tokens,
adapter_id: entry.request.adapter_id.clone(),
adapter_index: entry.request.adapter_index,
});
// Set batch_time
entry.batch_time = Some(Instant::now());
@ -430,7 +430,7 @@ mod tests {
stop_sequences: vec![],
},
top_n_tokens: 0,
adapter_id: None,
adapter_index: None,
},
response_tx,
span: info_span!("entry"),

View File

@ -302,7 +302,7 @@ pub(crate) struct GenerateParameters {
/// Lora adapter id
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub adapter_id: Option<String>,
pub adapter_index: Option<u32>,
}
fn default_max_new_tokens() -> Option<u32> {
@ -329,7 +329,7 @@ fn default_parameters() -> GenerateParameters {
seed: None,
top_n_tokens: None,
grammar: None,
adapter_id: None,
adapter_index: None,
}
}
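For reference, a hedged sketch of how a client could exercise the new adapter_index parameter once this lands; the endpoint, port, and adapter index value are assumptions for illustration, not part of this diff.

import requests

# Hypothetical request against a locally running server; adapter_index replaces
# the old adapter_id string and refers to an adapter already loaded under that index.
resp = requests.post(
    "http://localhost:3000/generate",
    json={
        "inputs": "What is LoRA?",
        "parameters": {
            "max_new_tokens": 32,
            "adapter_index": 0,  # assumed index of a pre-loaded adapter
        },
    },
)
print(resp.json())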

View File

@ -202,7 +202,7 @@ impl Validation {
decoder_input_details,
top_n_tokens,
grammar,
adapter_id,
adapter_index,
..
} = request.parameters;
@ -384,7 +384,7 @@ impl Validation {
parameters,
stopping_parameters,
top_n_tokens,
adapter_id,
adapter_index,
})
}
@ -680,7 +680,7 @@ pub(crate) struct ValidGenerateRequest {
pub parameters: ValidParameters,
pub stopping_parameters: ValidStoppingParameters,
pub top_n_tokens: u32,
pub adapter_id: Option<String>,
pub adapter_index: Option<u32>,
}
#[derive(Error, Debug)]

View File

@ -0,0 +1,31 @@
import json
from pathlib import Path
from typing import Dict, Optional
from text_generation_server.adapters.config import AdapterConfig
from text_generation_server.adapters.lora import LoraConfig
from text_generation_server.adapters.weights import (
AdapterBatchData,
AdapterBatchMetadata,
)
def load_adapter_config(
config_path: Optional[Path],
adapter_config_path: Optional[Path],
api_token: str,
) -> AdapterConfig:
if adapter_config_path is not None and adapter_config_path.exists():
return LoraConfig.load(str(adapter_config_path.parent), api_token)
raise ValueError(
f"No valid adapter config file found: "
f"tried {adapter_config_path} and {config_path}"
)
__all__ = [
"AdapterBatchData",
"AdapterBatchMetadata",
"load_adapter_config",
]

View File

@ -0,0 +1,37 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
import torch
from text_generation_server.adapters.weights import AdapterWeights
if TYPE_CHECKING:
from text_generation_server.models.model import Model
ModuleMap = Dict[str, Dict[str, Tuple[torch.Tensor, str]]]
@dataclass
class AdapterConfig(ABC):
base_model_name_or_path: str
@abstractmethod
def map_weights_for_model(
self,
adapter_weights: Dict,
weight_names: Tuple[str],
) -> Tuple[ModuleMap, Set[str]]:
pass
@abstractmethod
def load_batched_adapter_weights(
self,
model: "Model",
module_map: Dict[str, Dict],
layer_type: str,
unused_weight_names: Set[str],
dynamic: bool,
) -> Optional[AdapterWeights]:
pass

View File

@ -0,0 +1,430 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
import torch
from peft import LoraConfig as _LoraConfig
from torch.distributed import ProcessGroup
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
LORA = "lora"
from text_generation_server.adapters.weights import (
AdapterBatchMetadata,
AdapterWeights,
BatchAdapterWeights,
)
from text_generation_server.utils.sgmv import (
BGMV_MAX_RANK,
MAX_RANK_CUSTOM,
get_tmp_tensors,
orient_for_rank,
pad_rank,
use_cutlass_shrink,
)
if TYPE_CHECKING:
from text_generation_server.models.model import Model
EMPTY_TENSOR = torch.tensor([])
@dataclass
class LoraConfig(AdapterConfig):
r: int
target_modules: Optional[Union[List[str], str]]
fan_in_fan_out: bool
lora_alpha: int
use_rslora: bool
def map_weights_for_model(
self,
adapter_weights: Dict,
weight_names: Tuple[str],
) -> Tuple[ModuleMap, Set[str]]:
adapter_weight_names = set()
module_map = {}
for weight_name in weight_names:
lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
continue
module_map[weight_name] = {
"lora_A": (adapter_weights[lora_a_name], lora_a_name),
"lora_B": (adapter_weights[lora_b_name], lora_b_name),
}
adapter_weight_names.add(lora_a_name)
adapter_weight_names.add(lora_b_name)
return module_map, adapter_weight_names
def load_batched_adapter_weights(
self,
model: "Model",
module_map: Dict[str, Dict],
layer_type: str,
unused_weight_names: Set[str],
dynamic: bool,
) -> Optional[AdapterWeights]:
return LoraWeights.load(
self,
model,
module_map,
layer_type,
unused_weight_names,
)
@classmethod
def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
return cls(
base_model_name_or_path=hf_config.base_model_name_or_path,
r=hf_config.r,
target_modules=hf_config.target_modules,
fan_in_fan_out=hf_config.fan_in_fan_out,
lora_alpha=hf_config.lora_alpha,
use_rslora=(
hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False
),
)
class LoraWeights(AdapterWeights):
"""LoRA weights for a single adapter merged across all layers."""
def __init__(
self,
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
adapter_config: LoraConfig,
):
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
self._is_transposed = False
# [num_layers, hidden_size, r]
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
self._weights_a = torch.stack(weights_a)
# [num_layers, r, hidden_size]
self._weights_b = torch.stack(weights_b)
self.adapter_config = adapter_config
@property
def weights_a(self) -> torch.Tensor:
if self._is_transposed:
self._transpose_weights()
return self._weights_a
@property
def weights_b(self) -> torch.Tensor:
if self._is_transposed:
self._transpose_weights()
return self._weights_b
@property
def weights_a_t(self) -> torch.Tensor:
if not self._is_transposed:
self._transpose_weights()
return self._weights_a
@property
def weights_b_t(self) -> torch.Tensor:
if not self._is_transposed:
self._transpose_weights()
return self._weights_b
def _transpose_weights(self):
if self._use_cutlass_shrink:
# If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation
self._weights_a = self._weights_a.transpose(1, 2).contiguous()
self._weights_b = self._weights_b.transpose(1, 2).contiguous()
self._is_transposed = not self._is_transposed
@classmethod
def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
return [BatchLoraWeights]
@classmethod
def load(
cls,
config: LoraConfig,
model: "Model",
module_map: Dict[str, Dict],
layer_type: str,
unused_weight_names: Set[str],
) -> Optional[AdapterWeights]:
nlayers = model.get_num_layers_for_type(layer_type)
lora_a_list = [None] * nlayers
lora_b_list = [None] * nlayers
for layer_id in range(nlayers):
key = (layer_id, layer_type)
weight_name, layer = model.target_to_layer[key]
base_weight = layer.base_layer.linear.weight
base_device = base_weight.device
if weight_name not in module_map:
# There is no LoRA weight for this layer type in the adapter
return None
lora_a, lora_a_name = module_map[weight_name]["lora_A"]
lora_a = lora_a.to(base_device, model.dtype)
lora_b, lora_b_name = module_map[weight_name]["lora_B"]
lora_b = lora_b.to(base_device, model.dtype)
scale = get_scaling_factor(
config.lora_alpha,
config.r,
uses_rslora=config.use_rslora,
)
unused_weight_names.discard(lora_a_name)
unused_weight_names.discard(lora_b_name)
# Merge scaling factor into lora_b due to associativity of matrix multiplication:
# (A * B) * C = A * (B * C)
lora_a_list[layer_id] = lora_a.transpose(0, 1)
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
# pad lora ranks to be compatible with sgmv
lora_a_list = [
pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list
]
lora_b_list = [
pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list
]
if lora_a_list:
# update rank if it was padded
padded_rank = lora_a_list[0].size(1)
config.r = padded_rank
return LoraWeights(
*model.shard_lora_weights(lora_a_list, lora_b_list, layer_type),
config,
)
@dataclass
class RankSegments:
rank: int
lora_a_ptr: torch.Tensor
lora_b_ptr: torch.Tensor
# prefill (sgmv)
tmp_shrink: torch.Tensor
tmp_expand: torch.Tensor
segment_starts: torch.Tensor
segment_ends: torch.Tensor
# decode (bgmv)
indices: torch.Tensor
@dataclass
class BatchLoraWeights(BatchAdapterWeights):
lora_a: Dict[int, torch.Tensor]
lora_b: Dict[int, torch.Tensor]
adapter_index_configs: Dict[int, LoraConfig]
rank_data: Dict[int, RankSegments]
use_sgmv: bool
def has_adapter(self, adapter_index: int) -> bool:
return adapter_index in self.adapter_index_configs
def can_vectorize(self, pg: ProcessGroup) -> bool:
return all(
rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
for rank_data in self.rank_data.values()
)
@classmethod
def key(cls) -> str:
return LORA
@classmethod
def load(
self,
adapter_weights: Dict[int, AdapterWeights],
meta: AdapterBatchMetadata,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> Optional["BatchLoraWeights"]:
adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
adapter_weights = {
k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
}
if not adapter_weights:
return None
first_weights = list(adapter_weights.values())[0]
device = first_weights.weights_a.device
segment_indices = meta.segment_indices
lora_a = {
idx: adapter_weights[idx].weights_a
for idx in segment_indices
if idx in adapter_weights
}
lora_b = {
idx: adapter_weights[idx].weights_b
for idx in segment_indices
if idx in adapter_weights
}
max_rank = max(
adapter_weights[idx].lora_a_r
for idx in segment_indices
if idx in adapter_weights
)
if prefill or max_rank > BGMV_MAX_RANK:
use_sgmv = True
lora_a_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_a.data_ptr()
if idx in adapter_weights
else EMPTY_TENSOR.data_ptr()
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_b.data_ptr()
if idx in adapter_weights
else EMPTY_TENSOR.data_ptr()
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
else:
use_sgmv = False
lora_a_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_a_t.data_ptr()
if idx in adapter_weights
else EMPTY_TENSOR.data_ptr()
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_b_t.data_ptr()
if idx in adapter_weights
else EMPTY_TENSOR.data_ptr()
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
adapter_index_configs = {
idx: adapter_weights[idx].adapter_config
for idx in segment_indices
if idx in adapter_weights
}
adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
rank_indices = defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
if adapter_idx not in adapter_weights:
continue
rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
if prefill_head_indices is not None:
j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
for head_index in prefill_head_indices:
# j cannot go out of bounds as that would mean there are tokens without corresponding adapters
if head_index < meta.adapter_segments[j]:
prefill_head_segment_ends[-1] += 1
else:
prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
j += 1
rank_data = {}
for rank, indices in rank_indices.items():
tmp_shrink = None
tmp_expand = None
segment_starts = None
segment_ends = None
batch_indices = None
if use_sgmv:
lora_a_ptr_indices = lora_a_ptr[indices]
tmp_shrink, tmp_expand = get_tmp_tensors(
lora_a_ptr_indices.size(0), rank, device
)
segment_starts = meta.adapter_segments[indices]
segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
if prefill_head_indices is not None:
for i, segment_index in enumerate(indices):
segment_starts[i] = prefill_head_segment_starts[segment_index]
segment_ends[i] = prefill_head_segment_ends[segment_index]
else:
rank_indices = set(indices)
batch_indices = [
adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
]
batch_indices = [
idx if idx in rank_indices else -1 for idx in batch_indices
]
batch_indices = torch.tensor(
batch_indices, dtype=torch.int64, device=device
)
rank_data[rank] = RankSegments(
rank=rank,
tmp_shrink=tmp_shrink,
tmp_expand=tmp_expand,
lora_a_ptr=lora_a_ptr[indices],
lora_b_ptr=lora_b_ptr[indices],
segment_starts=segment_starts,
segment_ends=segment_ends,
indices=batch_indices,
)
return BatchLoraWeights(
lora_a=lora_a,
lora_b=lora_b,
adapter_index_configs=adapter_index_configs,
rank_data=rank_data,
use_sgmv=use_sgmv,
)
def get_scaling_factor(
lora_alpha: int,
r: int,
uses_rslora: bool = False,
) -> float:
"""Computes the scaling factor for the lora weights."""
if uses_rslora:
return lora_alpha / (r**0.5)
return lora_alpha / r
def _convert_lora(v: AdapterWeights) -> AdapterWeights:
if hasattr(v, "lora_weights"):
return v.lora_weights
return v
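As a quick sanity check on get_scaling_factor above, a minimal sketch with assumed values lora_alpha=16 and r=8:

lora_alpha, r = 16, 8

scale = lora_alpha / r                  # classic LoRA scaling -> 2.0
rslora_scale = lora_alpha / (r ** 0.5)  # rsLoRA scaling       -> ~5.66

# Because (x @ A) @ (B * scale) == ((x @ A) @ B) * scale, the scale is folded
# into lora_b once at load time instead of being applied on every forward pass.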

View File

@ -0,0 +1,159 @@
from abc import ABC, abstractclassmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Type
import torch
LORA = "lora"
LM_HEAD = "lm_head"
@dataclass
class AdapterBatchMetadata:
# [batch_size]
adapter_indices: torch.Tensor
# [num_adapters]
adapter_set: Set[int]
# [num_segments + 1]
adapter_segments: torch.Tensor
# [num_segments]
# maps from segment index to adapter index, i.e.:
# segment_indices[s] == adapter_indices[i]
segment_indices: List[int]
class AdapterWeights(ABC):
@abstractclassmethod
def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
pass
@property
def speculative_tokens(self) -> int:
return 0
class BatchAdapterWeights(ABC):
@abstractclassmethod
def has_adapter(self, adapter_index: int) -> bool:
pass
@abstractclassmethod
def key(cls) -> str:
pass
@abstractclassmethod
def load(
cls,
adapter_weights: Dict[int, AdapterWeights],
meta: "AdapterBatchMetadata",
prefill: bool,
prefill_head_indices: torch.Tensor,
) -> Optional["BatchAdapterWeights"]:
pass
class LayerAdapterWeights:
"""Adapter weights that apply to a particular layer."""
def __init__(self):
self.adapter_weights: Dict[int, AdapterWeights] = {}
def add_adapter(self, adapter_idx: int, weights: AdapterWeights):
self.adapter_weights[adapter_idx] = weights
def remove_adapter(self, adapter_idx: int):
if adapter_idx not in self.adapter_weights:
return
del self.adapter_weights[adapter_idx]
@property
def max_speculative_tokens(self) -> int:
return max(
adapter_weights.speculative_tokens
for adapter_weights in self.adapter_weights.values()
)
def is_empty(self) -> bool:
return len(self.adapter_weights) == 0
def get_data(
self,
meta: AdapterBatchMetadata,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> Dict[str, BatchAdapterWeights]:
# bucket adapters by batch class
adapter_batch_types: Dict[
Type[BatchAdapterWeights], Dict[int, AdapterWeights]
] = defaultdict(dict)
for adapter_index, adapter_weights in self.adapter_weights.items():
for batch_type in adapter_weights.get_batch_types():
adapter_batch_types[batch_type][adapter_index] = adapter_weights
batch_data = {}
for batch_type, adapter_weights in adapter_batch_types.items():
batched_weights = batch_type.load(
adapter_weights, meta, prefill, prefill_head_indices
)
if batched_weights is not None:
batch_data[batch_type.key()] = batched_weights
return batch_data
@dataclass
class AdapterBatchData:
meta: AdapterBatchMetadata
# layer type -> adapter type -> batch weight data
data: Dict[str, Dict[str, BatchAdapterWeights]]
prefill: bool
@staticmethod
def from_meta(
meta: AdapterBatchMetadata,
weights: Dict[str, LayerAdapterWeights],
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> "AdapterBatchData":
data = {}
for k, v in weights.items():
if v.is_empty():
continue
data[k] = v.get_data(
meta, prefill, prefill_head_indices if k == LM_HEAD else None
)
return AdapterBatchData(meta=meta, data=data, prefill=prefill)
def ranks(self) -> Set[int]:
# TODO(travis): refactor to be less coupled to lora implementation
ranks = set()
for layer_data in self.data.values():
lora_data = layer_data.get(LORA)
if lora_data is None:
continue
for rank_data in lora_data.rank_data.values():
ranks.add(rank_data.rank)
return ranks
def layer_names(self) -> Set[str]:
return set(self.data.keys())
def adapter_keys(self) -> Set[str]:
adapter_keys = set()
for layer_data in self.data.values():
adapter_keys.update(layer_data.keys())
return adapter_keys
@property
def max_rank(self) -> int:
ranks = self.ranks()
return max(ranks) if len(ranks) > 0 else 0
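To make the AdapterBatchMetadata fields above concrete, a small illustrative example with assumed values for a five-token batch where tokens 0-1 use adapter 0, tokens 2-3 use adapter 1, and token 4 uses adapter 0 again:

import torch

from text_generation_server.adapters.weights import AdapterBatchMetadata

meta = AdapterBatchMetadata(
    adapter_indices=torch.tensor([0, 0, 1, 1, 0]),  # [batch_size], adapter per token
    adapter_set={0, 1},                             # unique adapters in the batch
    adapter_segments=torch.tensor([0, 2, 4, 5]),    # [num_segments + 1] token offsets
    segment_indices=[0, 1, 0],                      # adapter index for each segment
)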

View File

@ -12,3 +12,9 @@ from text_generation_server.layers.speculative import SpeculativeHead
# Just to add the `load` methods.
from text_generation_server.layers.layernorm import load_layer_norm
from text_generation_server.layers.conv import load_conv2d
from text_generation_server.layers.lora import (
LoraLinear,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)

View File

@ -0,0 +1,244 @@
import math
import os
from typing import TYPE_CHECKING, Optional, Tuple
import torch
import torch.distributed
from accelerate import init_empty_weights
from torch import nn
from torch.nn import functional as F
from text_generation_server.utils.sgmv import (
add_lora_a_bgmv,
add_lora_b_bgmv,
has_sgmv,
lora_a_sgmv_cutlass,
lora_b_sgmv_cutlass,
orient_for_rank,
)
LORA = "lora"
MEDUSA = "medusa"
if TYPE_CHECKING:
from text_generation_server.adapters import AdapterBatchData
from text_generation_server.adapters.lora import BatchLoraWeights
class LoraLinear(nn.Module):
def __init__(self, base_layer, layer_id, process_group):
super().__init__()
self.base_layer = base_layer
self.layer_id = layer_id
self.process_group = process_group
def forward_layer_type(
self,
result: torch.Tensor,
input: torch.Tensor,
adapter_data: "AdapterBatchData",
layer_type: str,
start_idx: int,
end_idx: int,
) -> torch.Tensor:
data = adapter_data.data.get(layer_type)
data: Optional["BatchLoraWeights"] = (
data.get(LORA) if data is not None else None
)
if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
if end_idx - start_idx != result.shape[1]:
proj = torch.zeros_like(result[:, start_idx:end_idx])
else:
proj = result
for r, rank_segments in data.rank_data.items():
lora_a_ptr = rank_segments.lora_a_ptr
lora_b_ptr = rank_segments.lora_b_ptr
if data.use_sgmv:
# Use SGMV for prefill
if lora_a_ptr is not None and lora_b_ptr is not None:
v = lora_a_sgmv_cutlass(
input,
rank_segments.tmp_shrink,
lora_a_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
r,
)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
lora_b_sgmv_cutlass(
proj,
v,
rank_segments.tmp_expand,
lora_b_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
)
else:
# Use BGMV for decode
if lora_a_ptr is not None and lora_b_ptr is not None:
v = torch.zeros(
(input.size(0), r), dtype=input.dtype, device=input.device
)
add_lora_a_bgmv(
v,
input,
lora_a_ptr,
rank_segments.indices,
self.layer_id,
)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
add_lora_b_bgmv(
proj,
v,
lora_b_ptr,
rank_segments.indices,
self.layer_id,
)
if end_idx - start_idx != result.shape[1]:
result[:, start_idx:end_idx] += proj
else:
for adapter_index in adapter_data.meta.adapter_set:
if data is not None and data.has_adapter(adapter_index):
adapter_mask = (
(adapter_data.meta.adapter_indices == adapter_index)
.to(input.dtype)
.view(-1, 1)
)
layer_result = self.forward_lora(
input, data, adapter_index, adapter_mask
)
result[:, start_idx:end_idx] += layer_result
return result
def forward_lora(
self,
input: torch.Tensor,
data: "BatchLoraWeights",
adapter_index: int,
adapter_mask: torch.Tensor,
) -> torch.Tensor:
lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
lora_a = orient_for_rank(lora_a, lora_b.size(0))
a_out = input @ lora_a
if self.process_group.size() > 1:
a_out = self.collect_lora_a(a_out)
result = (a_out @ lora_b) * adapter_mask
return result
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
raise NotImplementedError("Implemented in subclasses")
class TensorParallelMultiAdapterLinear(LoraLinear):
def __init__(self, base_layer, layer_id, layer_names, sizes, process_group):
super().__init__(base_layer, layer_id, process_group)
self.layer_names = layer_names
self.sizes = sizes
@classmethod
def load(cls, base_layer, layer_id, layer_names, sizes, process_group):
return TensorParallelMultiAdapterLinear(
base_layer, layer_id, layer_names, sizes, process_group
)
def forward(
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
) -> torch.Tensor:
result = self.base_layer(input)
# handle models like Bloom that have inputs of shape
# (batch_size, sequence_length, hidden_size):
# reshape them to (batch_size * sequence_length, hidden_size)
# for the LoRA computation, then reshape back
prev_shape = result.shape
is_3d = len(input.shape) >= 3
if is_3d:
input = input.reshape(-1, input.shape[-1])
result = result.reshape(-1, result.shape[-1])
offset = 0
for i, layer_name in enumerate(self.layer_names):
start_idx = offset // self.process_group.size()
if self.sizes is not None:
offset += self.sizes[i]
end_idx = offset // self.process_group.size()
else:
end_idx = result.shape[1]
result = self.forward_layer_type(
result, input, adapter_data, layer_name, start_idx, end_idx
)
if is_3d:
result = result.reshape(prev_shape)
return result
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
# Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
# We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
#
# TODO(travis): this is not very efficient as we do an all-gather for every adapter,
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
gathered_tensors = [
torch.empty_like(a_out) for _ in range(self.process_group.size())
]
torch.distributed.all_gather(gathered_tensors, a_out)
return torch.cat(gathered_tensors, dim=1)
class TensorParallelAdapterRowLinear(LoraLinear):
def __init__(self, base_layer, layer_id, layer_name, process_group):
super().__init__(base_layer, layer_id, process_group)
self.layer_name = layer_name
@classmethod
def load(cls, base_layer, layer_id, layer_name, process_group):
return cls(base_layer, layer_id, layer_name, process_group)
def forward(
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
) -> torch.Tensor:
result = self.base_layer(input)
# Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
stride = result.shape[-1] // self.process_group.size()
start_idx = self.process_group.rank() * stride
end_idx = (self.process_group.rank() + 1) * stride
self.forward_layer_type(
result, input, adapter_data, self.layer_name, start_idx, end_idx
)
return result
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
# Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
# We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
#
# TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
torch.distributed.all_reduce(a_out, group=self.process_group)
return a_out
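When the SGMV/BGMV kernels are unavailable, forward_layer_type falls back to forward_lora above, which is just two matmuls per adapter masked to the rows that use it. A plain-PyTorch sketch of that math with made-up shapes (hidden_size=16, rank=4, 3 tokens):

import torch

hidden_size, r, tokens = 16, 4, 3
x = torch.randn(tokens, hidden_size)
lora_a = torch.randn(hidden_size, r)                       # per-layer slice of weights_a
lora_b = torch.randn(r, hidden_size)                       # per-layer slice of weights_b
adapter_mask = torch.tensor([1.0, 0.0, 1.0]).view(-1, 1)   # rows that use this adapter

base_out = x @ torch.randn(hidden_size, hidden_size)       # stand-in for the base layer
delta = ((x @ lora_a) @ lora_b) * adapter_mask             # LoRA contribution
result = base_out + delta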

View File

@ -38,6 +38,8 @@ from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
@ -50,6 +52,16 @@ if SYSTEM == "rocm":
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
# Constants
Q_PROJ = "q_proj"
K_PROJ = "k_proj"
V_PROJ = "v_proj"
O_PROJ = "o_proj"
GATE_PROJ = "gate_proj"
UP_PROJ = "up_proj"
DOWN_PROJ = "down_proj"
def load_attention(config, prefix, weights):
# Only defined in granite.
@ -57,7 +69,7 @@ def load_attention(config, prefix, weights):
# if specific model type, load the correct attention
if config.model_type == "phi3":
return TensorParallelColumnLinear.load_qkv(
base_layer = TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.qkv_proj",
weights=weights,
@ -66,7 +78,7 @@ def load_attention(config, prefix, weights):
num_key_value_heads=config.num_key_value_heads,
)
elif config.model_type == "baichuan":
return TensorParallelColumnLinear.load_qkv(
base_layer = TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.W_pack",
weights=weights,
@ -76,7 +88,7 @@ def load_attention(config, prefix, weights):
)
# otherwise, load the default attention based on the number of heads
return TensorParallelColumnLinear.load_multi(
base_layer = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
@ -84,6 +96,19 @@ def load_attention(config, prefix, weights):
bias=bias,
)
head_size = config.hidden_size // config.num_attention_heads
return TensorParallelMultiAdapterLinear.load(
base_layer,
layer_id,
[Q_PROJ, K_PROJ, V_PROJ],
sizes=[
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
],
process_group=weights.process_group,
)
class FlashLlamaAttention(torch.nn.Module):
def __init__(
@ -124,7 +149,7 @@ class FlashLlamaAttention(torch.nn.Module):
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
self.query_key_value = load_attention(config, prefix, weights, index)
self.index = index
self.adapter_weights = {}
adapter_names = list(lora_weights.keys())
@ -161,12 +186,20 @@ class FlashLlamaAttention(torch.nn.Module):
pre_multiplied_lora_matrix
)
self.o_proj = TensorParallelRowLinear.load(
o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.o_proj = TensorParallelAdapterRowLinear.load(
o_proj,
index,
O_PROJ,
process_group=weights.process_group,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@ -185,8 +218,9 @@ class FlashLlamaAttention(torch.nn.Module):
max_s,
batch_lora_adapter_mask,
lora_indices,
adapter_data,
):
qkv = self.query_key_value(hidden_states)
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
@ -197,32 +231,6 @@ class FlashLlamaAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
batch_size = query.size(0)
# hidden states without LoRA
hs_wl = hidden_states[lora_indices == -1]
adapted_query_states = [hs_wl]
adapted_value_states = [hs_wl]
for ind in range(self.n_loras):
mask = lora_indices == ind
hs_sub = hidden_states[mask]
mat_q = torch.matmul(hs_sub, self.pre_multiplied_lora_matrices[ind, 0])
mat_v = torch.matmul(hs_sub, self.pre_multiplied_lora_matrices[ind, 1])
adapted_query_states.append(mat_q)
adapted_value_states.append(mat_v)
query_adapted = torch.cat(adapted_query_states, dim=0).view(
batch_size, self.num_heads, self.head_size
)
value_adapted = torch.cat(adapted_value_states, dim=0).view(
batch_size, self.num_key_value_heads, self.head_size
)
query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask]
kv[batch_lora_adapter_mask, 1] += value_adapted[batch_lora_adapter_mask]
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
@ -260,7 +268,7 @@ class FlashLlamaAttention(torch.nn.Module):
class LlamaMLP(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, index):
super().__init__()
self.hidden_act = config.hidden_act
self.act = (
@ -278,26 +286,46 @@ class LlamaMLP(nn.Module):
# Fuse gate and up proj
bias = getattr(config, "mlp_bias", False)
if config.model_type == "phi3":
self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
gate_up_proj = TensorParallelColumnLinear.load_gate_up(
config,
prefix=f"{prefix}.gate_up_proj",
weights=weights,
bias=bias,
)
else:
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=bias,
)
self.down_proj = TensorParallelRowLinear.load(
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
index,
[GATE_PROJ, UP_PROJ],
sizes=[
config.intermediate_size,
config.intermediate_size,
],
process_group=weights.process_group,
)
down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=bias,
)
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
index,
DOWN_PROJ,
process_group=weights.process_group,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
@ -337,7 +365,9 @@ class FlashLlamaLayer(nn.Module):
lora_weights=lora_weights,
lora_configs=lora_configs,
)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.mlp = LlamaMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@ -362,6 +392,7 @@ class FlashLlamaLayer(nn.Module):
max_s,
batch_lora_adapter_mask,
lora_indices,
adapter_data,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -378,6 +409,7 @@ class FlashLlamaLayer(nn.Module):
max_s,
batch_lora_adapter_mask,
lora_indices,
adapter_data,
)
# faster post attention rms norm
@ -440,6 +472,7 @@ class FlashLlamaModel(torch.nn.Module):
prefill_cache_indices: Optional[torch.Tensor],
batch_lora_adapter_mask: Optional[List[str]],
lora_indices: Optional[torch.Tensor],
adapter_data,
) -> torch.Tensor:
hidden_states = inputs_embeds
@ -464,6 +497,7 @@ class FlashLlamaModel(torch.nn.Module):
max_s,
batch_lora_adapter_mask,
lora_indices,
adapter_data,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -512,6 +546,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
lm_head_indices: Optional[torch.Tensor] = None,
batch_lora_adapter_mask: Optional[List[str]] = None,
lora_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
@ -527,6 +562,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
prefill_cache_indices=prefill_cache_indices,
batch_lora_adapter_mask=batch_lora_adapter_mask,
lora_indices=lora_indices,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
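The sizes passed by load_attention above drive the per-layer-type slicing in TensorParallelMultiAdapterLinear.forward. A worked example with assumed Llama-style dimensions (head_size=128, 32 attention heads, 8 KV heads, 2 shards):

head_size, num_heads, num_kv_heads, world_size = 128, 32, 8, 2
sizes = [head_size * num_heads, head_size * num_kv_heads, head_size * num_kv_heads]

offset = 0
for name, size in zip(["q_proj", "k_proj", "v_proj"], sizes):
    start_idx = offset // world_size
    offset += size
    end_idx = offset // world_size
    print(name, start_idx, end_idx)
# q_proj 0 2048
# k_proj 2048 2560
# v_proj 2560 3072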

View File

@ -13,6 +13,7 @@ from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Iterable, Optional, Tuple, List, Type, Dict
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
@ -31,6 +32,7 @@ from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
import text_generation_server.models.globals as tgi_globals
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
from text_generation_server.utils.import_utils import (
empty_cache,
@ -114,6 +116,9 @@ class FlashCausalLMBatch(Batch):
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Adapter metadata for each request
adapter_meta: AdapterBatchMetadata
# Number of blocks in this batch
num_blocks: int
# Maximum number of blocks
@ -174,6 +179,9 @@ class FlashCausalLMBatch(Batch):
stopping_criterias = []
top_n_tokens = []
adapter_indices_list = []
adapter_set = set()
# Cumulative length
cumulative_length = 0
cumulative_max_length = 0
@ -225,6 +233,9 @@ class FlashCausalLMBatch(Batch):
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
adapter_indices_list.append(torch.full((input_length,), r.adapter_index))
adapter_set.add(r.adapter_index)
# Paged attention
# Remove one as the first token does not have a past
speculative_length = get_speculate()
@ -296,6 +307,10 @@ class FlashCausalLMBatch(Batch):
max_length, input_length + max_new_tokens + speculative_length
)
adapter_indices = torch.cat(adapter_indices_list).to(
dtype=torch.int64, device=device
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device, tokenizer
)
@ -339,6 +354,11 @@ class FlashCausalLMBatch(Batch):
input_lengths, dtype=torch.int32, device=device
)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
@ -393,6 +413,12 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor,
num_blocks=num_blocks,
max_blocks=max_blocks,
adapter_meta=AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
speculative_ids=None,
)
@ -443,6 +469,7 @@ class FlashCausalLMBatch(Batch):
stopping_criterias = []
top_n_tokens = []
adapter_set = set()
num_blocks = 0
max_blocks = 0
@ -471,6 +498,8 @@ class FlashCausalLMBatch(Batch):
top_n_tokens.append(self.top_n_tokens[idx])
adapter_set.add(self.requests[idx].adapter_index)
remaining_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
@ -498,6 +527,7 @@ class FlashCausalLMBatch(Batch):
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
adapter_indices = self.adapter_meta.adapter_indices[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices]
block_tables_tensor = self.block_tables_tensor[indices]
input_lengths_tensor = self.input_lengths_tensor[indices]
@ -513,6 +543,11 @@ class FlashCausalLMBatch(Batch):
# Move to GPU now that we have the whole tensor
slot_indices = slot_indices.to(device)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
return type(self)(
batch_id=self.batch_id,
requests=requests,
@ -543,6 +578,12 @@ class FlashCausalLMBatch(Batch):
num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
adapter_meta=AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
)
@classmethod
@ -596,6 +637,14 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
total_indices_size = sum(
b.adapter_meta.adapter_indices.shape[0] for b in batches
)
adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
total_indices_size
)
adapter_set = set()
adapter_segment_builder = SegmentConcatBuilder()
start_slots = []
block_tables = []
@ -613,6 +662,7 @@ class FlashCausalLMBatch(Batch):
# Cumulative length
cumulative_batch_size = 0
cumulative_slots = 0
cumulative_adapter_indices_size = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
@ -637,6 +687,18 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
slots[slots_start_index:slots_end_index] = batch.slots
# Copy over adapter indices
adapter_start_index = cumulative_adapter_indices_size
adapter_end_index = (
cumulative_adapter_indices_size
+ batch.adapter_meta.adapter_indices.shape[0]
)
adapter_indices[adapter_start_index:adapter_end_index] = (
batch.adapter_meta.adapter_indices
)
cumulative_adapter_indices_size = adapter_end_index
adapter_set.update(batch.adapter_meta.adapter_set)
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length]
@ -680,6 +742,8 @@ class FlashCausalLMBatch(Batch):
else None
)
_adapter_segments, _adapter_segment_indices = adapter_segment_builder.build()
return cls(
batch_id=batches[0].batch_id,
requests=requests,
@ -719,6 +783,7 @@ class FlashCausalLMBatch(Batch):
class FlashCausalLM(Model):
def __init__(
self,
model_id: str,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
num_layers: int,
@ -738,6 +803,7 @@ class FlashCausalLM(Model):
self.kv_cache = []
super(FlashCausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=False,
@ -996,7 +1062,7 @@ class FlashCausalLM(Model):
)
def forward(
self, batch: FlashCausalLMBatch
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
if batch.speculative_ids is not None:
@ -1066,13 +1132,6 @@ class FlashCausalLM(Model):
batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
lora_indices = torch.full((bs,), -1, dtype=torch.int32, device=self.device)
for i, r in enumerate(batch.requests):
if r.adapter_id:
lora_index = self.model.get_lora_index(r.adapter_id)
input_length = batch.input_lengths[i]
lora_indices[i : i + input_length] = lora_index
batch_lora_adapter_mask[i] = True
if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
@ -1087,6 +1146,7 @@ class FlashCausalLM(Model):
lm_head_indices=lm_head_indices,
batch_lora_adapter_mask=batch_lora_adapter_mask,
lora_indices=lora_indices,
adapter_data=adapter_data,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
@ -1123,7 +1183,34 @@ class FlashCausalLM(Model):
prefill = batch.cu_seqlen_prefill is not None
prefill_logprobs = batch.prefill_next_token_indices is not None
out, speculative_logits = self.forward(batch)
# Update adapter indices for speculative tokens (if present)
adapter_meta = batch.adapter_meta
if batch.speculative_ids is not None:
B, speculative_length = batch.speculative_ids.shape
new_length = speculative_length + 1
adapter_indices = (
adapter_meta.adapter_indices.unsqueeze(-1)
.expand(B, new_length)
.reshape(-1)
)
adapter_segments = adapter_meta.adapter_segments * new_length
adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_meta.adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_meta.segment_indices,
)
# Assign pointers to adapter weights
# TODO(travis): don't update this if indices haven't changed
adapter_data = AdapterBatchData.from_meta(
adapter_meta,
self.layer_to_adapter_weights,
prefill,
batch.prefill_head_indices,
)
out, speculative_logits = self.forward(batch, adapter_data)
if prefill:
next_token_logits = (

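The speculative-decoding branch above repeats each request's adapter index for the extra speculative tokens and rescales the segment boundaries. A small sketch with assumed values (two requests, speculative_length=2):

import torch

adapter_indices = torch.tensor([0, 1])   # one adapter index per request
speculative_length = 2
new_length = speculative_length + 1

expanded = adapter_indices.unsqueeze(-1).expand(2, new_length).reshape(-1)
print(expanded)                          # tensor([0, 0, 0, 1, 1, 1])

adapter_segments = torch.tensor([0, 1, 2])
print(adapter_segments * new_length)     # tensor([0, 3, 6]), offsets in the expanded batch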
View File

@ -4,7 +4,7 @@ import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from typing import Optional
from typing import Optional, Tuple, Dict, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
@ -22,6 +22,30 @@ tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.lora import LoraConfig
Q_PROJ = "q_proj"
K_PROJ = "k_proj"
V_PROJ = "v_proj"
O_PROJ = "o_proj"
GATE_PROJ = "gate_proj"
UP_PROJ = "up_proj"
DOWN_PROJ = "down_proj"
LM_HEAD = "lm_head"
# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
ADAPTER_LAYERS = [
Q_PROJ,
K_PROJ,
V_PROJ,
O_PROJ,
GATE_PROJ,
UP_PROJ,
DOWN_PROJ,
] # LM_HEAD
ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}
class FlashLlama(FlashCausalLM):
def __init__(
@ -80,6 +104,7 @@ class FlashLlama(FlashCausalLM):
)
torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
@ -90,3 +115,59 @@ class FlashLlama(FlashCausalLM):
rank=rank,
world_size=world_size,
)
@property
def supports_adapter_loading(self) -> bool:
return True
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}
prefix = "model.layers"
for i, layer in enumerate(self.model.model.layers):
layer_weights[(i, Q_PROJ)] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, K_PROJ)] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, V_PROJ)] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, O_PROJ)] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
layer_weights[(i, GATE_PROJ)] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, UP_PROJ)] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, DOWN_PROJ)] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
return layer_weights
@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS
@property
def default_traced_adapter_layers(self) -> List[str]:
return [Q_PROJ, V_PROJ]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL

View File

@ -2,12 +2,50 @@ import inspect
import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
from collections import defaultdict
from transformers import PreTrainedTokenizerBase, PretrainedConfig
from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights
from text_generation_server.utils.adapter import (
load_and_merge_adapters,
AdapterParameters,
AdapterSource,
)
from loguru import logger
BASE_MODEL_ADAPTER_ID = "__base_model__"
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
block_size = size // world_size
start = offset + rank * block_size
stop = offset + (rank + 1) * block_size
return start, stop
def shard_on_dim(
t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
):
world_size = process_group.size()
rank = process_group.rank()
size = t.shape[dim]
start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
if dim == 0:
tensor = t[start:stop]
elif dim == 1:
tensor = t[:, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
return tensor
B = TypeVar("B", bound=Batch)
@ -15,6 +53,7 @@ B = TypeVar("B", bound=Batch)
class Model(ABC):
def __init__(
self,
model_id: str,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
requires_padding: bool,
@ -25,6 +64,7 @@ class Model(ABC):
sliding_window: Optional[int] = None,
speculate: Optional[int] = None,
):
self.model_id = model_id
self.model = model.eval()
self.tokenizer = tokenizer
@ -42,6 +82,12 @@ class Model(ABC):
self.world_size = world_size
self.sliding_window = sliding_window if sliding_window != -1 else None
self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
LayerAdapterWeights
)
self.target_to_layer = self.adapter_target_to_layer()
self.loaded_adapters = set()
if speculate is None:
speculate = get_speculate()
self.speculate = speculate
@ -119,3 +165,156 @@ class Model(ABC):
raise RuntimeError(
f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
)
@property
def supports_adapter_loading(self) -> bool:
return False
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
return {}
@property
def adapter_layers(self) -> List[str]:
return []
@property
def default_traced_adapter_layers(self) -> List[str]:
return []
def get_num_layers_for_type(self, layer_type: str) -> int:
return 0
def is_row_parallel(self, layer_type: str) -> bool:
return False
@property
def max_speculative_tokens(self) -> int:
return max(
[
weights.max_speculative_tokens
for weights in self.layer_to_adapter_weights.values()
],
default=0,
)
def load_adapter(
self,
adapter_parameters: AdapterParameters,
adapter_source: AdapterSource,
adapter_index: int,
api_token: str,
dynamic: bool = True,
):
"""Loads adapter weights from disk / host memory on the GPU.
adapter_id must be `BASE_MODEL_ADAPTER_ID` if the adapter is statically loaded
into the model. Otherwise, the adapter weights are applied during the forward
pass and stored separately from the base model parameters.
"""
if adapter_index in self.loaded_adapters:
# Adapter already loaded
return
if not self.supports_adapter_loading:
raise ValueError("This model does not support adapter loading.")
if dynamic and not self.dynamic_adapter_loading_enabled:
raise ValueError(
f"This model was initialized with the adapter {self.static_adapter_id} "
f"and therefore does not support dynamic adapter loading. "
f"Please initialize a new model instance from the base model in "
f"order to use the dynamic adapter loading feature."
)
logger.info(
f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
)
weight_names = tuple([v[0] for v in self.target_to_layer.values()])
(
module_map,
adapter_config,
adapter_weight_names,
adapter_tokenizer,
) = load_and_merge_adapters(
self.model_id,
adapter_parameters,
adapter_source,
adapter_index,
weight_names,
api_token,
False,
)
unused_weight_names = adapter_weight_names.copy()
for layer_name in self.adapter_layers:
adapter_weights = adapter_config.load_batched_adapter_weights(
self,
module_map,
layer_name,
unused_weight_names,
dynamic,
)
if adapter_weights is None:
continue
layer_weights = self.layer_to_adapter_weights[layer_name]
layer_weights.add_adapter(adapter_index, adapter_weights)
if len(unused_weight_names) > 0:
logger.warning(
f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
)
if adapter_tokenizer is not None:
self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
self.loaded_adapters.add(adapter_index)
def shard_lora_weights(
self,
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
layer_type: str,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
# [hidden_size, r]
split_dim = 0 if self.is_row_parallel(layer_type) else 1
weights_a = [
shard_on_dim(w, dim=split_dim, process_group=self.process_group)
for w in weights_a
]
# [r, hidden_size]
weights_b = [
shard_on_dim(w, dim=1, process_group=self.process_group) for w in weights_b
]
return weights_a, weights_b
def offload_adapter(
self,
adapter_parameters: AdapterParameters,
adapter_source: AdapterSource,
adapter_index: int,
):
"""Offloads the adapter weights from GPU to CPU or disk."""
if adapter_index not in self.loaded_adapters:
# Adapter already offloaded
return
if not self.supports_adapter_loading:
raise ValueError("This model does not support adapter loading.")
if not self.dynamic_adapter_loading_enabled:
raise ValueError(
f"This model was initialized with the adapter {self.static_adapter_id} "
f"and therefore does not support dynamic adapter loading. "
f"Please initialize a new model instance from the base model in "
f"order to use the dynamic adapter loading feature."
)
for layer_name in self.adapter_layers:
if layer_name in self.layer_to_adapter_weights:
self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index)
self.loaded_adapters.remove(adapter_index)
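shard_on_dim above gives each rank a contiguous slice along the chosen dimension; shard_lora_weights uses it to split the LoRA A and B matrices across tensor-parallel ranks. A minimal sketch with a made-up 3x4 tensor and world_size=2:

import torch

t = torch.arange(12).reshape(3, 4)
world_size, dim = 2, 1

for rank in range(world_size):
    block = t.shape[dim] // world_size        # same block computation as shard_on_dim
    start, stop = rank * block, (rank + 1) * block
    print(rank, t[:, start:stop].shape)       # rank 0 -> (3, 2), rank 1 -> (3, 2)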

View File

@ -30,6 +30,9 @@ except (ImportError, NotImplementedError):
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_model_id
from text_generation_server.utils.adapter import (
AdapterParameters,
)
class SignalHandler:
@ -235,6 +238,30 @@ def serve(
trust_remote_code,
max_input_tokens,
)
# TODO: avoid hacky hardcoded adapter id
adapter_parameters = AdapterParameters(
adapter_ids=lora_adapter_ids,
weights=[
# TODO: fill with actual weights
torch.tensor([1.0], dtype=torch.float32)
],
merge_strategy=0,
density=0.0,
majority_sign_method=0,
)
adapter_source = None
adapter_index = None
api_token = None
model.load_adapter(
adapter_parameters,
adapter_source,
adapter_index,
api_token,
False,
)
except Exception:
logger.exception("Error when initializing model")
raise

View File

@ -0,0 +1,196 @@
import warnings
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Set, Tuple
from safetensors.torch import load_file
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
from text_generation_server.pb import generate_pb2
from text_generation_server.utils.merges.strategies import merge_adapters
from text_generation_server.utils import hub
from text_generation_server.adapters.lora import LoraConfig
if TYPE_CHECKING:
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
BASE_MODEL_ADAPTER_ID = "__base_model__"
class AdapterParameters:
def __init__(
self, adapter_ids, weights, merge_strategy, density, majority_sign_method
):
self.adapter_ids = adapter_ids
self.weights = weights
self.merge_strategy = merge_strategy
self.density = density
self.majority_sign_method = majority_sign_method
class AdapterSource:
def __init__(self, adapter_id: str, model_id: str, revision: str):
self.adapter_id = adapter_id
self.model_id = model_id
self.revision = revision
def load_and_merge_adapters(
model_id: str,
adapter_parameters: AdapterParameters,
adapter_source: str,
adapter_index: int,
weight_names: Tuple[str],
api_token: str,
trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
if len(adapter_parameters.adapter_ids) == 1:
return load_module_map(
model_id,
adapter_parameters.adapter_ids[0],
adapter_source,
weight_names,
api_token,
trust_remote_code,
)
adapter_params = AdapterParametersContainer(
adapter_parameters, adapter_source, adapter_index
)
return _load_and_merge(
model_id, adapter_params, weight_names, api_token, trust_remote_code
)
class AdapterParametersContainer:
def __init__(self, adapter_parameters, adapter_source, adapter_index):
self.adapter_parameters = adapter_parameters
self.adapter_source = adapter_source
self.adapter_index = adapter_index
def __hash__(self) -> int:
return self.adapter_index
@lru_cache(maxsize=32)
def _load_and_merge(
model_id: str,
adapter_params: AdapterParametersContainer,
weight_names: Tuple[str],
api_token: str,
trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
params = adapter_params.adapter_parameters
adapters_to_merge = []
merged_weight_names = set()
tokenizer = None
for adapter_id in params.adapter_ids:
if adapter_id == BASE_MODEL_ADAPTER_ID:
raise ValueError("Base model adapter cannot be merged.")
module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
load_module_map(
model_id,
adapter_id,
adapter_params.adapter_source,
weight_names,
api_token,
trust_remote_code,
)
)
adapters_to_merge.append((module_map, adapter_config))
merged_weight_names = merged_weight_names.union(adapter_weight_names)
if tokenizer is None:
tokenizer = adapter_tokenizer
if len(adapters_to_merge) == 0:
raise ValueError("No adapters to merge.")
module_map, adapter_config = merge_adapters(adapters_to_merge, params)
return module_map, adapter_config, merged_weight_names, tokenizer
def check_architectures(
model_id: str,
adapter_id: str,
adapter_config: "AdapterConfig",
trust_remote_code: bool = False,
):
try:
if not adapter_config.base_model_name_or_path:
# Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None)
return
expected_config = AutoConfig.from_pretrained(
model_id, trust_remote_code=trust_remote_code
)
model_config = AutoConfig.from_pretrained(
adapter_config.base_model_name_or_path, trust_remote_code=trust_remote_code
)
except Exception as e:
warnings.warn(
f"Unable to check architecture compatibility for adapter '{adapter_id}' "
f"against model '{model_id}'. Assuming they are compatible. Error: {e}"
)
return
if model_config.architectures == expected_config.architectures:
warnings.warn(
f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. "
f"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead."
)
else:
# TODO(travis): revisit this when we support classification heads which will not use CausalLM
raise ValueError(
f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. "
f"Use --model-id '{adapter_config.base_model_name_or_path}' instead."
)
@lru_cache(maxsize=128)
def load_module_map(
model_id: str,
adapter_id: str,
adapter_source: str,
weight_names: Tuple[str],
api_token: str,
trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
print("adapter_id", adapter_id)
revision = "main"
adapter_config = LoraConfig.load(adapter_id, api_token)
if adapter_config.base_model_name_or_path != model_id:
check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
adapter_filenames = hub._cached_adapter_weight_files(
adapter_id, revision=revision, extension=".safetensors"
)
try:
adapter_tokenizer = AutoTokenizer.from_pretrained(
adapter_config.config_path,
token=api_token,
trust_remote_code=trust_remote_code,
)
except Exception:
# Adapter does not have a tokenizer, so fall back to the base model tokenizer
adapter_tokenizer = None
# load adapter weights from all shards (should have relatively small memory footprint)
adapter_weights = {}
for filename in adapter_filenames:
adapter_weights.update(load_file(filename))
# map the model weights to the relevant adapter weights (LoRA A and B matrices)
module_map, adapter_weight_names = adapter_config.map_weights_for_model(
adapter_weights, weight_names
)
return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
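For completeness, a hedged sketch of calling load_module_map directly; the model and adapter ids below are placeholders, and api_token is left as None for a public adapter:

module_map, adapter_config, weight_names, adapter_tokenizer = load_module_map(
    model_id="meta-llama/Llama-2-7b-hf",              # placeholder base model
    adapter_id="some-org/some-lora-adapter",          # placeholder adapter repo
    adapter_source="hub",                             # assumed source identifier
    weight_names=("model.layers.0.self_attn.q_proj",),
    api_token=None,
    trust_remote_code=False,
)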

View File

@ -32,6 +32,7 @@ class LoraConfig:
task_type="CAUSAL_LM",
use_dora=False,
use_rslora=False,
config_path=None,
):
self.alpha_pattern = alpha_pattern or {}
self.auto_mapping = auto_mapping
@ -57,12 +58,13 @@ class LoraConfig:
self.task_type = task_type
self.use_dora = use_dora
self.use_rslora = use_rslora
self.config_path = config_path
@classmethod
def from_file(cls, filename):
with open(filename, "r") as f:
json_data = json.load(f)
return cls(**json_data)
return cls(**json_data, config_path=filename)
# TODO: support fetching the model from the hub if it's not in the cache
@classmethod

View File

@ -0,0 +1,223 @@
import copy
from abc import ABC
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union
import torch
from text_generation_server.utils.merges.utils import (
    calculate_majority_sign_mask,
    disjoint_merge,
    prune,
)
if TYPE_CHECKING:
    from text_generation_server.adapters.lora import LoraConfig
    from text_generation_server.utils.adapter import ModuleMap
class AdapterParameters:
    def __init__(
        self, adapter_ids, weights, merge_strategy, density, majority_sign_method
    ):
        self.adapter_ids = adapter_ids
        self.weights = weights
        self.merge_strategy = merge_strategy
        self.density = density
        self.majority_sign_method = majority_sign_method
def _apply_weights(
tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor
) -> torch.Tensor:
if isinstance(tensors, torch.Tensor):
t = tensors
else:
t = torch.stack(tensors, dim=0)
# element-wise weighting of each task tensor
# need to unsqueeze weights to match task tensor dimensions
# for multiplication to apply element-wise
while len(t.shape) > len(w.shape):
w = w.unsqueeze(-1)
return t * w
class MergeStrategy(ABC):
def merge(
self, task_tensors: List[torch.Tensor], weights: torch.Tensor
) -> torch.Tensor:
raise NotImplementedError()
class LinearMerge(MergeStrategy):
def __init__(self, **kwargs):
pass
def merge(
self, task_tensors: List[torch.Tensor], weights: torch.Tensor
) -> torch.Tensor:
weighted_task_tensors = _apply_weights(task_tensors, weights)
return weighted_task_tensors.sum(dim=0)
class TiesMerge(MergeStrategy):
def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
self.density = density
self.majority_sign_method = majority_sign_method
def merge(
self, task_tensors: List[torch.Tensor], weights: torch.Tensor
) -> torch.Tensor:
# sparsify
task_tensors = [
prune(tensor, self.density, method="magnitude") for tensor in task_tensors
]
task_tensors = torch.stack(task_tensors, dim=0)
# elect sign before applying weights
majority_sign_mask = calculate_majority_sign_mask(
task_tensors, method=self.majority_sign_method
)
weighted_task_tensors = _apply_weights(task_tensors, weights)
# disjoint merge
return disjoint_merge(weighted_task_tensors, majority_sign_mask)
class DareLinearMerge(MergeStrategy):
def __init__(self, density: float, **kwargs):
self.density = density
def merge(
self, task_tensors: List[torch.Tensor], weights: torch.Tensor
) -> torch.Tensor:
# sparsify
task_tensors = [
prune(tensor, self.density, method="random", rescale=True)
for tensor in task_tensors
]
weighted_task_tensors = _apply_weights(task_tensors, weights)
return weighted_task_tensors.sum(dim=0)
class DareTiesMerge(MergeStrategy):
def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
self.density = density
self.majority_sign_method = majority_sign_method
def merge(
self, task_tensors: List[torch.Tensor], weights: torch.Tensor
) -> torch.Tensor:
# sparsify
task_tensors = [
prune(tensor, self.density, method="random", rescale=True)
for tensor in task_tensors
]
task_tensors = torch.stack(task_tensors, dim=0)
# elect sign before applying weights
majority_sign_mask = calculate_majority_sign_mask(
task_tensors, method=self.majority_sign_method
)
weighted_task_tensors = _apply_weights(task_tensors, weights)
# disjoint merge
mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask)
return mixed_task_tensors
strategy_registry: Dict[str, Type[MergeStrategy]] = {
"linear": LinearMerge,
"ties": TiesMerge,
"dare_linear": DareLinearMerge,
"dare_ties": DareTiesMerge,
}
def merge_adapters(
adapters: List[Tuple["ModuleMap", "LoraConfig"]],
merge_params: AdapterParameters,
) -> Tuple["ModuleMap", "LoraConfig"]:
# strategy_name = MergeStrategyEnum.Name(merge_params.merge_strategy).lower()
strategy_name = "linear"
weights = merge_params.weights
if not weights:
weights = torch.ones(len(adapters))
else:
weights = torch.tensor(weights)
merge_config = {
"density": merge_params.density,
# "majority_sign_method": MajoritySignMethodEnum.Name(
# merge_params.majority_sign_method
# ).lower(),
"majority_sign_method": "total",
}
merge_strategy = strategy_registry[strategy_name](**merge_config)
module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict(
lambda: defaultdict(lambda: defaultdict(list))
)
lora_configs = []
weight_name_to_adapter_idx = defaultdict(list)
# input is list of (module_map, lora_config) tuples
# convert into dict[k][param_name] -> list of tensors
for idx, (module_map, lora_config) in enumerate(adapters):
for weight_name, data in module_map.items():
weight_name_to_adapter_idx[weight_name].append(idx)
for k, (param_data, param_name) in data.items():
module_maps[weight_name][k][param_name].append(param_data)
lora_configs.append(lora_config)
# validate lora configs are compatible
_validate_lora_configs(lora_configs)
# merge tensors for each module such that we have a single ModuleMap:
# dict[k] -> merged tensor
merged_module_map: "ModuleMap" = defaultdict(dict)
for weight_name, data in module_maps.items():
indices = weight_name_to_adapter_idx[weight_name]
param_weights = weights[indices]
for k, param_data in data.items():
for param_name, tensors in param_data.items():
merged_tensor = merge_strategy.merge(tensors, param_weights)
merged_module_map[weight_name][k] = (merged_tensor, param_name)
# merge lora configs
merged_lora_config = _merge_lora_configs(lora_configs)
return merged_module_map, merged_lora_config
def _validate_lora_configs(lora_configs: List["LoraConfig"]):
# check that all configs have the same rank
ranks = set(lora_config.r for lora_config in lora_configs)
if len(ranks) > 1:
raise ValueError(
f"unable to merge adapters, lora configs have different ranks: {ranks}"
)
if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs):
raise ValueError(
"unable to merge adapters, lora configs have no target modules"
)
def _merge_lora_configs(lora_configs: List["LoraConfig"]) -> "LoraConfig":
merged_lora_config = copy.copy(lora_configs[0])
# merge target modules as a union operation
merged_target_modules = sorted(
set(
module
for lora_config in lora_configs
for module in lora_config.target_modules
)
)
merged_lora_config.target_modules = merged_target_modules
return merged_lora_config
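# Illustrative sketch: merging two already-loaded adapters with equal weights.
# `module_map_a`/`config_a` and `module_map_b`/`config_b` are hypothetical
# placeholders (e.g. outputs of load_module_map); note that merge_strategy and
# majority_sign_method are currently hard-coded to "linear"/"total" above.
#
#   params = AdapterParameters(
#       adapter_ids=["adapter_a", "adapter_b"],
#       weights=[0.5, 0.5],
#       merge_strategy="linear",
#       density=1.0,
#       majority_sign_method="total",
#   )
#   merged_module_map, merged_config = merge_adapters(
#       [(module_map_a, config_a), (module_map_b, config_b)], params
#   )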

View File

@@ -0,0 +1,108 @@
# coding=utf-8
# From: https://github.com/huggingface/peft/pull/1364
# Copyright 2024-present the HuggingFace Inc. team.
# Modifications by Predibase, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal
import torch
def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor:
"""
Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction
`density`.
Args:
tensor (`torch.Tensor`):The tensor to prune.
density (`float`):The fraction of values to preserve. Should be in [0,1].
"""
mask = torch.zeros_like(tensor).reshape(-1)
k = int(density * tensor.reshape(-1).shape[0])
top_k = torch.topk(tensor.abs().reshape(-1), k=k, largest=True)
mask[top_k[1]] = 1
return tensor * mask.reshape(tensor.shape)
def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:
"""
    Prune random values of the task tensors based on the specified fraction `density`.
Args:
tensor (`torch.Tensor`):The tensor to prune.
density (`float`):The fraction of values to preserve. Should be in [0,1].
rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor.
"""
mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density))
pruned_tensor = tensor * mask
if rescale:
        pruned_tensor = torch.div(input=pruned_tensor, other=density)
return pruned_tensor
def prune(
tensor: torch.Tensor,
density: float,
method: Literal["magnitude", "random"],
rescale: bool = False,
) -> torch.Tensor:
"""
Prune the values of task tensors based on the `method`.
Args:
tensor (`torch.Tensor`):The tensor to prune.
density (`float`):The fraction of values to preserve. Should be in [0,1].
method (`str`):The method to use to prune. Should be one of ["magnitude", "random"].
rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor.
"""
if density >= 1:
return tensor
elif density < 0:
raise ValueError("Density should be >= 0, got {density}")
if method == "magnitude":
return magnitude_based_pruning(tensor, density)
elif method == "random":
return random_pruning(tensor, density, rescale=rescale)
else:
raise ValueError(f"Unknown method {method}")
def calculate_majority_sign_mask(
tensor: torch.Tensor, method: Literal["total", "frequency"] = "total"
):
"""
Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0.
Args:
tensor (`torch.Tensor`):The tensor to get the mask from.
method (`str`):The method to use to get the mask. Should be one of ["total", "frequency"].
"""
sign = tensor.sign()
if method == "total":
sign_magnitude = (sign * tensor.abs()).sum(dim=0)
elif method == "frequency":
sign_magnitude = sign.sum(dim=0)
else:
raise RuntimeError(f'Unimplemented mask method "{method}"')
majority_sign = torch.where(sign_magnitude >= 0, 1, -1)
return sign == majority_sign
def disjoint_merge(task_tensors, majority_sign_mask):
mixed_task_tensors = (task_tensors * majority_sign_mask).sum(dim=0)
num_params_preserved = majority_sign_mask.sum(dim=0)
return mixed_task_tensors / torch.clamp(num_params_preserved, min=1.0)
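# Worked example (illustrative) of the TIES-style pipeline above, with unit
# weights and density ~2/3 on two toy task tensors:
#
#   t = torch.stack([
#       prune(torch.tensor([1.0, -2.0, 3.0]), 0.67, method="magnitude"),  # -> [0., -2., 3.]
#       prune(torch.tensor([-1.0, 2.0, 0.5]), 0.67, method="magnitude"),  # -> [-1., 2., 0.]
#   ])
#   mask = calculate_majority_sign_mask(t, method="total")  # elected signs: [-1, +1, +1]
#   disjoint_merge(t, mask)  # -> tensor([-1., 2., 3.])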

View File

@@ -0,0 +1,62 @@
from typing import List, Tuple, Union
import torch
def find_segments(
adapter_indices: Union[torch.Tensor, List[int]]
) -> Tuple[List[int], List[int]]:
segments = [0]
segment_indices = []
if isinstance(adapter_indices, torch.Tensor):
# Calling .item() repeatedly on CUDA tensor is very slow, so we move it to CPU first
adapter_indices = adapter_indices.cpu().tolist()
start_index = 0
for i in range(1, len(adapter_indices)):
if adapter_indices[i] != adapter_indices[i - 1]:
segments.append(i)
segment_indices.append(adapter_indices[i - 1])
start_index = i
# Handle the last segment
if start_index < len(adapter_indices):
segments.append(len(adapter_indices))
segment_indices.append(adapter_indices[-1])
return segments, segment_indices
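# Example (illustrative): a batch whose rows use adapters [0, 0, 1, 1, 1, 2] gives
#   find_segments([0, 0, 1, 1, 1, 2]) == ([0, 2, 5, 6], [0, 1, 2])
# i.e. rows [0, 2) use adapter 0, rows [2, 5) adapter 1, and rows [5, 6) adapter 2.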
class SegmentConcatBuilder:
def __init__(self):
self.adapter_segment_indices = []
self.adapter_segment_tensors = []
def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]):
# Update adapter segments
if self.adapter_segment_tensors:
# Because we have already processed at least one batch, remove the 0 start index
# from this batch denoting the beginning of the segment, then offset all segment
# positions by the value of the last segment in the previous batch to account for
# the concatenation.
adapter_segments = (
adapter_segments[1:] + self.adapter_segment_tensors[-1][-1]
)
if (
self.adapter_segment_indices
and self.adapter_segment_indices[-1] == segment_indices[0]
):
# If the last segment in the previous batch is the same as the first segment in this batch,
# then we merge them together into a single segment. In effect, this means removing it from
# the segment indices of this batch, and extending the segment span by removing the segment
# end index from the previous batch.
segment_indices = segment_indices[1:]
self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1]
self.adapter_segment_indices.extend(segment_indices)
self.adapter_segment_tensors.append(adapter_segments)
def build(self) -> Tuple[torch.Tensor, List[int]]:
return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices
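# Worked example (illustrative): concatenating segment metadata across two batches.
#
#   builder = SegmentConcatBuilder()
#   builder.concat(torch.tensor([0, 2, 5]), [0, 1])  # batch 1: adapter 0 rows [0, 2), adapter 1 rows [2, 5)
#   builder.concat(torch.tensor([0, 3]), [1])        # batch 2: adapter 1 rows [0, 3)
#   builder.build()  # -> (tensor([0, 2, 8]), [0, 1]); the adjoining adapter-1 segments are merged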

View File

@@ -0,0 +1,242 @@
import os
import warnings
from functools import lru_cache
from typing import List, Tuple
import torch
import torch.nn.functional as F
try:
# TODO: add build steps for Punica kernels
# import punica_kernels as _kernels
import punica.ops as _kernels
HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
except ImportError:
warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.")
_kernels = None
HAS_SGMV = False
MIN_SGMV_RANK = 8
MIN_RANK_CUSTOM = 16
MAX_RANK_CUSTOM = 128
SGMV_BLOCK_SIZE = 16
BGMV_MAX_RANK = 64
def has_sgmv() -> bool:
return HAS_SGMV
def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
"""Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size."""
if not has_sgmv():
return t
# tensor parallelism will result in effective rank being divided by world_size,
# so we need to scale the min rank to offset that effect
min_rank = MIN_SGMV_RANK * world_size
# if we're at or below the min rank, pad up to the min rank
# otherwise, pad to the nearest multiple of the block size
current_rank = t.size(dim)
target_rank = (
min_rank
if current_rank <= min_rank
else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
)
if current_rank == target_rank:
return t
pad_size = target_rank - current_rank
    # see the (somewhat complicated) pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
pad = [0, 0] * t.dim()
pad[(t.dim() - dim - 1) * 2 + 1] = pad_size
pad = tuple(pad)
return F.pad(t, pad, mode="constant", value=0.0)
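# Example (illustrative): with world_size=1, a rank-10 tensor is padded to 16 along
# `dim` (the next multiple of SGMV_BLOCK_SIZE) and a rank-4 tensor is padded up to
# MIN_SGMV_RANK=8; with world_size=2 the effective minimum rank becomes 16.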
def use_cutlass_shrink(lora_rank: int) -> bool:
return lora_rank < MIN_RANK_CUSTOM
def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
return t.transpose(0, 1)
return t
# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py
def add_lora_sgmv_cutlass(
y: torch.Tensor,
x: torch.Tensor,
wa_ptr: torch.Tensor,
wb_ptr: torch.Tensor,
s_start: torch.Tensor,
s_end: torch.Tensor,
layer_idx: int,
lora_rank: int,
):
"""
Semantics:
y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
Weight matrix shape: `[num_layers, R, H1]`.
wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
Weight matrix shape: `[num_layers, R, H2]`.
s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices.
s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices.
layer_idx: Layer index of the weight matrices.
"""
if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
# Custom SGMV shrink only supports rank 16, 32, 64, 128
_add_lora_sgmv_cutlass_legacy(
y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank
)
return
tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
_kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
_kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)
def _add_lora_sgmv_cutlass_legacy(
y: torch.Tensor,
x: torch.Tensor,
wa_ptr: torch.Tensor,
wb_ptr: torch.Tensor,
s_start: torch.IntTensor,
s_end: torch.IntTensor,
layer_idx: int,
lora_rank: int,
):
tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
_kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
_kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
@lru_cache(maxsize=1)
def get_tmp_tensor(device: torch.device) -> torch.Tensor:
return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)
@lru_cache(maxsize=32)
def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
tmp_size = _kernels.sgmv_cutlass_tmp_size(size)
return torch.empty((tmp_size,), dtype=torch.uint8, device=device)
def get_tmp_expand_size(size: int) -> int:
return _kernels.sgmv_cutlass_tmp_size(size)
def get_tmp_tensors(
nsegments: int, lora_rank: int, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
if use_cutlass_shrink(lora_rank):
tmp = get_tmp_tensor_for_size(nsegments, device)
return tmp, tmp
else:
tmp_shrink = get_tmp_tensor(device)
tmp_expand = get_tmp_tensor_for_size(nsegments, device)
return tmp_shrink, tmp_expand
def lora_a_sgmv_cutlass(
x: torch.Tensor,
tmp: torch.Tensor,
wa_ptr: torch.Tensor,
s_start: torch.IntTensor,
s_end: torch.IntTensor,
layer_idx: int,
lora_rank: int,
) -> torch.Tensor:
v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
_kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
else:
_kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
return v
def lora_b_sgmv_cutlass(
y: torch.Tensor,
v: torch.Tensor,
tmp: torch.Tensor,
wb_ptr: torch.Tensor,
s_start: torch.IntTensor,
s_end: torch.IntTensor,
layer_idx: int,
):
_kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
v: Shape: `[B, R]`. Temporary vector.
x: Shape: `[B, H1]`. Input vectors.
wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices.
wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
"""
def add_lora_a_bgmv(
v: torch.Tensor,
x: torch.Tensor,
wa_T_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
):
_kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)
def add_lora_b_bgmv(
y: torch.Tensor,
v: torch.Tensor,
wb_T_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
):
_kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)
def segmented_matmul(
y: torch.Tensor,
x: torch.Tensor,
w: List[torch.Tensor],
b: List[torch.Tensor],
s_start: torch.IntTensor,
s_end: torch.IntTensor,
):
for i in range(len(w)):
if s_end[i] - s_start[i] <= 0:
continue
xi = x[s_start[i] : s_end[i]]
wi = w[i]
bi = b[i]
y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi)
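# Illustrative sketch of the two-step SGMV path built from the helpers above; it
# assumes `x`, `y`, the weight pointer tensors `wa_ptr`/`wb_ptr`, the segment
# boundaries `s_start`/`s_end`, `layer_idx`, and `lora_rank` have already been
# prepared for the current batch:
#
#   tmp_shrink, tmp_expand = get_tmp_tensors(wa_ptr.size(0), lora_rank, x.device)
#   v = lora_a_sgmv_cutlass(x, tmp_shrink, wa_ptr, s_start, s_end, layer_idx, lora_rank)
#   lora_b_sgmv_cutlass(y, v, tmp_expand, wb_ptr, s_start, s_end, layer_idx)
#   # y is updated in place: roughly y[s_i:s_{i+1}] += x[s_i:s_{i+1}] @ A_i.T @ B_i per segment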