feat: prefer lorax implementation and port loading logic
parent c661631225
commit 8b50f4b779
@@ -157,7 +157,7 @@ async fn prefill(
            top_n_tokens: top_n_tokens.unwrap_or(0),
            blocks: vec![],
            slots: vec![],
            adapter_id: None,
            adapter_index: None,
        })
        .collect();

@@ -107,8 +107,8 @@ message Request {
    bool prefill_logprobs = 6;
    /// Return most likely n tokens
    uint32 top_n_tokens = 7;
    /// LORA adapter id
    optional string adapter_id = 8;
    /// LORA adapter index
    optional uint32 adapter_index = 9;
}

message Batch {

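For orientation, a minimal sketch of how the two new fields surface on the Python server once the protobuf stubs are regenerated; the id, inputs and adapter values below are illustrative, not taken from this diff:

from text_generation_server.pb import generate_pb2

request = generate_pb2.Request(
    id=0,
    inputs="Hello",
    top_n_tokens=0,
    adapter_id="my-org/my-lora",  # hypothetical adapter name
    adapter_index=1,              # index assigned by the router
)
# Requests that target the base model simply leave both adapter fields unset.
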
@@ -154,7 +154,7 @@ impl Client {
                }),
                prefill_logprobs: true,
                top_n_tokens: 20,
                adapter_id: None,
                adapter_index: None,
            });
            n_tokens += max_input_length;

@@ -290,7 +290,7 @@ impl State {
                    entry.request.stopping_parameters.clone(),
                )),
                top_n_tokens: entry.request.top_n_tokens,
                adapter_id: entry.request.adapter_id.clone(),
                adapter_index: entry.request.adapter_index,
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());

@@ -430,7 +430,7 @@ mod tests {
                stop_sequences: vec![],
            },
            top_n_tokens: 0,
            adapter_id: None,
            adapter_index: None,
        },
        response_tx,
        span: info_span!("entry"),

@@ -302,7 +302,7 @@ pub(crate) struct GenerateParameters {
    /// Lora adapter id
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
    pub adapter_id: Option<String>,
    pub adapter_index: Option<u32>,
}

fn default_max_new_tokens() -> Option<u32> {

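The `adapter_id` parameter also becomes part of the public generate API; a hypothetical client call (endpoint, port and adapter name are placeholders) could look like:

import requests

response = requests.post(
    "http://127.0.0.1:3000/generate",
    json={
        "inputs": "What is LoRA?",
        "parameters": {"max_new_tokens": 32, "adapter_id": "my-org/my-lora"},
    },
)
print(response.json())
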
@@ -329,7 +329,7 @@ fn default_parameters() -> GenerateParameters {
        seed: None,
        top_n_tokens: None,
        grammar: None,
        adapter_id: None,
        adapter_index: None,
    }
}

@@ -202,7 +202,7 @@ impl Validation {
            decoder_input_details,
            top_n_tokens,
            grammar,
            adapter_id,
            adapter_index,
            ..
        } = request.parameters;

@@ -384,7 +384,7 @@ impl Validation {
            parameters,
            stopping_parameters,
            top_n_tokens,
            adapter_id,
            adapter_index,
        })
    }

@@ -680,7 +680,7 @@ pub(crate) struct ValidGenerateRequest {
    pub parameters: ValidParameters,
    pub stopping_parameters: ValidStoppingParameters,
    pub top_n_tokens: u32,
    pub adapter_id: Option<String>,
    pub adapter_index: Option<u32>,
}

#[derive(Error, Debug)]

@@ -0,0 +1,31 @@
import json
from pathlib import Path
from typing import Dict, Optional

from text_generation_server.adapters.config import AdapterConfig
from text_generation_server.adapters.lora import LoraConfig
from text_generation_server.adapters.weights import (
    AdapterBatchData,
    AdapterBatchMetadata,
)


def load_adapter_config(
    config_path: Optional[Path],
    adapter_config_path: Optional[Path],
    api_token: str,
) -> AdapterConfig:
    if adapter_config_path is not None and adapter_config_path.exists():
        return LoraConfig.load(str(adapter_config_path.parent), api_token)

    raise ValueError(
        f"No valid adapter config file found: "
        f"tried {adapter_config_path} and {config_path}"
    )


__all__ = [
    "AdapterBatchData",
    "AdapterBatchMetadata",
    "load_adapter_config",
]

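A minimal usage sketch of the helper above; the path and empty token are illustrative only:

from pathlib import Path
from text_generation_server.adapters import load_adapter_config

adapter_config = load_adapter_config(
    config_path=None,
    adapter_config_path=Path("/data/adapters/my-lora/adapter_config.json"),
    api_token="",
)
# Returns a LoraConfig when adapter_config.json exists, otherwise raises ValueError.
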
@@ -0,0 +1,37 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple

import torch

from text_generation_server.adapters.weights import AdapterWeights

if TYPE_CHECKING:
    from text_generation_server.models.model import Model


ModuleMap = Dict[str, Dict[str, Tuple[torch.Tensor, str]]]


@dataclass
class AdapterConfig(ABC):
    base_model_name_or_path: str

    @abstractmethod
    def map_weights_for_model(
        self,
        adapter_weights: Dict,
        weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        pass

    @abstractmethod
    def load_batched_adapter_weights(
        self,
        model: "Model",
        module_map: Dict[str, Dict],
        layer_type: str,
        unused_weight_names: Set[str],
        dynamic: bool,
    ) -> Optional[AdapterWeights]:
        pass

@ -0,0 +1,430 @@
|
|||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from peft import LoraConfig as _LoraConfig
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
|
||||
|
||||
LORA = "lora"
|
||||
from text_generation_server.adapters.weights import (
|
||||
AdapterBatchMetadata,
|
||||
AdapterWeights,
|
||||
BatchAdapterWeights,
|
||||
)
|
||||
from text_generation_server.utils.sgmv import (
|
||||
BGMV_MAX_RANK,
|
||||
MAX_RANK_CUSTOM,
|
||||
get_tmp_tensors,
|
||||
orient_for_rank,
|
||||
pad_rank,
|
||||
use_cutlass_shrink,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from text_generation_server.models.model import Model
|
||||
|
||||
EMPTY_TENSOR = torch.tensor([])
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoraConfig(AdapterConfig):
|
||||
r: int
|
||||
target_modules: Optional[Union[List[str], str]]
|
||||
fan_in_fan_out: bool
|
||||
lora_alpha: int
|
||||
use_rslora: bool
|
||||
|
||||
def map_weights_for_model(
|
||||
self,
|
||||
adapter_weights: Dict,
|
||||
weight_names: Tuple[str],
|
||||
) -> Tuple[ModuleMap, Set[str]]:
|
||||
adapter_weight_names = set()
|
||||
module_map = {}
|
||||
for weight_name in weight_names:
|
||||
lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
|
||||
lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
|
||||
if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
|
||||
continue
|
||||
|
||||
module_map[weight_name] = {
|
||||
"lora_A": (adapter_weights[lora_a_name], lora_a_name),
|
||||
"lora_B": (adapter_weights[lora_b_name], lora_b_name),
|
||||
}
|
||||
adapter_weight_names.add(lora_a_name)
|
||||
adapter_weight_names.add(lora_b_name)
|
||||
return module_map, adapter_weight_names
|
||||
|
||||
def load_batched_adapter_weights(
|
||||
self,
|
||||
model: "Model",
|
||||
module_map: Dict[str, Dict],
|
||||
layer_type: str,
|
||||
unused_weight_names: Set[str],
|
||||
dynamic: bool,
|
||||
) -> Optional[AdapterWeights]:
|
||||
return LoraWeights.load(
|
||||
self,
|
||||
model,
|
||||
module_map,
|
||||
layer_type,
|
||||
unused_weight_names,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
|
||||
hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
|
||||
return cls(
|
||||
base_model_name_or_path=hf_config.base_model_name_or_path,
|
||||
r=hf_config.r,
|
||||
target_modules=hf_config.target_modules,
|
||||
fan_in_fan_out=hf_config.fan_in_fan_out,
|
||||
lora_alpha=hf_config.lora_alpha,
|
||||
use_rslora=(
|
||||
hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class LoraWeights(AdapterWeights):
|
||||
"""LoRA weights for a single adapter merged across all layers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weights_a: List[torch.Tensor],
|
||||
weights_b: List[torch.Tensor],
|
||||
adapter_config: LoraConfig,
|
||||
):
|
||||
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
|
||||
self.lora_b_r = weights_b[0].size(0) if len(weights_b) > 0 else 1
|
||||
|
||||
self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
|
||||
self._is_transposed = False
|
||||
|
||||
# [num_layers, hidden_size, r]
|
||||
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
|
||||
self._weights_a = torch.stack(weights_a)
|
||||
|
||||
# [num_layers, r, hidden_size]
|
||||
self._weights_b = torch.stack(weights_b)
|
||||
|
||||
self.adapter_config = adapter_config
|
||||
|
||||
@property
|
||||
def weights_a(self) -> torch.Tensor:
|
||||
if self._is_transposed:
|
||||
self._transpose_weights()
|
||||
return self._weights_a
|
||||
|
||||
@property
|
||||
def weights_b(self) -> torch.Tensor:
|
||||
if self._is_transposed:
|
||||
self._transpose_weights()
|
||||
return self._weights_b
|
||||
|
||||
@property
|
||||
def weights_a_t(self) -> torch.Tensor:
|
||||
if not self._is_transposed:
|
||||
self._transpose_weights()
|
||||
return self._weights_a
|
||||
|
||||
@property
|
||||
def weights_b_t(self) -> torch.Tensor:
|
||||
if not self._is_transposed:
|
||||
self._transpose_weights()
|
||||
return self._weights_b
|
||||
|
||||
def _transpose_weights(self):
|
||||
if self._use_cutlass_shrink:
|
||||
# Only the A weights need transposing with the cutlass shrink; otherwise SGMV and BGMV already share the same orientation for A
|
||||
self._weights_a = self._weights_a.transpose(1, 2).contiguous()
|
||||
self._weights_b = self._weights_b.transpose(1, 2).contiguous()
|
||||
self._is_transposed = not self._is_transposed
|
||||
|
||||
@classmethod
|
||||
def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
|
||||
return [BatchLoraWeights]
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
config: LoraConfig,
|
||||
model: "Model",
|
||||
module_map: Dict[str, Dict],
|
||||
layer_type: str,
|
||||
unused_weight_names: Set[str],
|
||||
) -> Optional[AdapterWeights]:
|
||||
nlayers = model.get_num_layers_for_type(layer_type)
|
||||
lora_a_list = [None] * nlayers
|
||||
lora_b_list = [None] * nlayers
|
||||
|
||||
for layer_id in range(nlayers):
|
||||
key = (layer_id, layer_type)
|
||||
weight_name, layer = model.target_to_layer[key]
|
||||
base_weight = layer.base_layer.linear.weight
|
||||
base_device = base_weight.device
|
||||
|
||||
if weight_name not in module_map:
|
||||
# There is no LoRA weight for this layer type in the adapter
|
||||
return None
|
||||
|
||||
lora_a, lora_a_name = module_map[weight_name]["lora_A"]
|
||||
lora_a = lora_a.to(base_device, model.dtype)
|
||||
|
||||
lora_b, lora_b_name = module_map[weight_name]["lora_B"]
|
||||
lora_b = lora_b.to(base_device, model.dtype)
|
||||
|
||||
scale = get_scaling_factor(
|
||||
config.lora_alpha,
|
||||
config.r,
|
||||
uses_rslora=config.use_rslora,
|
||||
)
|
||||
|
||||
unused_weight_names.discard(lora_a_name)
|
||||
unused_weight_names.discard(lora_b_name)
|
||||
|
||||
# Merge scaling factor into lora_b due to associativity of matrix multiplication:
|
||||
# (A * B) * C = A * (B * C)
|
||||
lora_a_list[layer_id] = lora_a.transpose(0, 1)
|
||||
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
|
||||
|
||||
# pad lora ranks to be compatible with sgmv
|
||||
lora_a_list = [
|
||||
pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list
|
||||
]
|
||||
lora_b_list = [
|
||||
pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list
|
||||
]
|
||||
|
||||
if lora_a_list:
|
||||
# update rank if it was padded
|
||||
padded_rank = lora_a_list[0].size(1)
|
||||
config.r = padded_rank
|
||||
|
||||
return LoraWeights(
|
||||
*model.shard_lora_weights(lora_a_list, lora_b_list, layer_type),
|
||||
config,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RankSegments:
|
||||
rank: int
|
||||
|
||||
lora_a_ptr: torch.Tensor
|
||||
lora_b_ptr: torch.Tensor
|
||||
|
||||
# prefill (sgmv)
|
||||
tmp_shrink: torch.Tensor
|
||||
tmp_expand: torch.Tensor
|
||||
segment_starts: torch.Tensor
|
||||
segment_ends: torch.Tensor
|
||||
|
||||
# decode (bgmv)
|
||||
indices: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchLoraWeights(BatchAdapterWeights):
|
||||
lora_a: Dict[int, torch.Tensor]
|
||||
lora_b: Dict[int, torch.Tensor]
|
||||
adapter_index_configs: Dict[int, LoraConfig]
|
||||
rank_data: Dict[int, RankSegments]
|
||||
use_sgmv: bool
|
||||
|
||||
def has_adapter(self, adapter_index: int) -> bool:
|
||||
return adapter_index in self.adapter_index_configs
|
||||
|
||||
def can_vectorize(self, pg: ProcessGroup) -> bool:
|
||||
return all(
|
||||
rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
|
||||
for rank_data in self.rank_data.values()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def key(cls) -> str:
|
||||
return LORA
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
self,
|
||||
adapter_weights: Dict[int, AdapterWeights],
|
||||
meta: AdapterBatchMetadata,
|
||||
prefill: bool,
|
||||
prefill_head_indices: Optional[torch.Tensor],
|
||||
) -> Optional["BatchLoraWeights"]:
|
||||
adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
|
||||
adapter_weights = {
|
||||
k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
|
||||
}
|
||||
if not adapter_weights:
|
||||
return None
|
||||
|
||||
first_weights = list(adapter_weights.values())[0]
|
||||
device = first_weights.weights_a.device
|
||||
segment_indices = meta.segment_indices
|
||||
|
||||
lora_a = {
|
||||
idx: adapter_weights[idx].weights_a
|
||||
for idx in segment_indices
|
||||
if idx in adapter_weights
|
||||
}
|
||||
lora_b = {
|
||||
idx: adapter_weights[idx].weights_b
|
||||
for idx in segment_indices
|
||||
if idx in adapter_weights
|
||||
}
|
||||
|
||||
max_rank = max(
|
||||
adapter_weights[idx].lora_a_r
|
||||
for idx in segment_indices
|
||||
if idx in adapter_weights
|
||||
)
|
||||
|
||||
if prefill or max_rank > BGMV_MAX_RANK:
|
||||
use_sgmv = True
|
||||
lora_a_ptr = torch.tensor(
|
||||
[
|
||||
(
|
||||
adapter_weights[idx].weights_a.data_ptr()
|
||||
if idx in adapter_weights
|
||||
else EMPTY_TENSOR.data_ptr()
|
||||
)
|
||||
for idx in segment_indices
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
lora_b_ptr = torch.tensor(
|
||||
[
|
||||
(
|
||||
adapter_weights[idx].weights_b.data_ptr()
|
||||
if idx in adapter_weights
|
||||
else EMPTY_TENSOR.data_ptr()
|
||||
)
|
||||
for idx in segment_indices
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
use_sgmv = False
|
||||
lora_a_ptr = torch.tensor(
|
||||
[
|
||||
(
|
||||
adapter_weights[idx].weights_a_t.data_ptr()
|
||||
if idx in adapter_weights
|
||||
else EMPTY_TENSOR.data_ptr()
|
||||
)
|
||||
for idx in segment_indices
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
lora_b_ptr = torch.tensor(
|
||||
[
|
||||
(
|
||||
adapter_weights[idx].weights_b_t.data_ptr()
|
||||
if idx in adapter_weights
|
||||
else EMPTY_TENSOR.data_ptr()
|
||||
)
|
||||
for idx in segment_indices
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
|
||||
adapter_index_configs = {
|
||||
idx: adapter_weights[idx].adapter_config
|
||||
for idx in segment_indices
|
||||
if idx in adapter_weights
|
||||
}
|
||||
|
||||
adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
|
||||
|
||||
rank_indices = defaultdict(list)
|
||||
for segment_idx, adapter_idx in enumerate(segment_indices):
|
||||
if adapter_idx not in adapter_weights:
|
||||
continue
|
||||
rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
|
||||
|
||||
if prefill_head_indices is not None:
|
||||
j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
|
||||
for head_index in prefill_head_indices:
|
||||
# j cannot go out of bounds as that would mean there are tokens without corresponding adapters
|
||||
if head_index < meta.adapter_segments[j]:
|
||||
prefill_head_segment_ends[-1] += 1
|
||||
else:
|
||||
prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
|
||||
prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
|
||||
j += 1
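# Example (made-up values): with meta.adapter_segments=[0, 3, 5] and
# prefill_head_indices=[2, 4] (one kept row per request), the loop above yields
# prefill_head_segment_starts=[0, 1] and prefill_head_segment_ends=[1, 2],
# i.e. each segment is remapped onto the compacted row space used by lm_head.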
|
||||
|
||||
rank_data = {}
|
||||
for rank, indices in rank_indices.items():
|
||||
tmp_shrink = None
|
||||
tmp_expand = None
|
||||
segment_starts = None
|
||||
segment_ends = None
|
||||
batch_indices = None
|
||||
|
||||
if use_sgmv:
|
||||
lora_a_ptr_indices = lora_a_ptr[indices]
|
||||
tmp_shrink, tmp_expand = get_tmp_tensors(
|
||||
lora_a_ptr_indices.size(0), rank, device
|
||||
)
|
||||
segment_starts = meta.adapter_segments[indices]
|
||||
segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
|
||||
if prefill_head_indices is not None:
|
||||
for i, segment_index in enumerate(indices):
|
||||
segment_starts[i] = prefill_head_segment_starts[segment_index]
|
||||
segment_ends[i] = prefill_head_segment_ends[segment_index]
|
||||
else:
|
||||
rank_indices = set(indices)
|
||||
batch_indices = [
|
||||
adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
|
||||
]
|
||||
batch_indices = [
|
||||
idx if idx in rank_indices else -1 for idx in batch_indices
|
||||
]
|
||||
batch_indices = torch.tensor(
|
||||
batch_indices, dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
rank_data[rank] = RankSegments(
|
||||
rank=rank,
|
||||
tmp_shrink=tmp_shrink,
|
||||
tmp_expand=tmp_expand,
|
||||
lora_a_ptr=lora_a_ptr[indices],
|
||||
lora_b_ptr=lora_b_ptr[indices],
|
||||
segment_starts=segment_starts,
|
||||
segment_ends=segment_ends,
|
||||
indices=batch_indices,
|
||||
)
|
||||
|
||||
return BatchLoraWeights(
|
||||
lora_a=lora_a,
|
||||
lora_b=lora_b,
|
||||
adapter_index_configs=adapter_index_configs,
|
||||
rank_data=rank_data,
|
||||
use_sgmv=use_sgmv,
|
||||
)
|
||||
|
||||
|
||||
def get_scaling_factor(
|
||||
lora_alpha: int,
|
||||
r: int,
|
||||
uses_rslora: bool = False,
|
||||
) -> float:
|
||||
"""Computes the scaling factor for the lora weights."""
|
||||
if uses_rslora:
|
||||
return lora_alpha / (r**0.5)
|
||||
return lora_alpha / r
|
||||
|
||||
|
||||
def _convert_lora(v: AdapterWeights) -> AdapterWeights:
|
||||
if hasattr(v, "lora_weights"):
|
||||
return v.lora_weights
|
||||
return v
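As a sanity check on the scaling rule in `get_scaling_factor` above (numbers are illustrative): with lora_alpha=16 and r=8 the classic LoRA scale is 16 / 8 = 2.0, while rsLoRA gives 16 / sqrt(8) ≈ 5.66; this is the factor folded into `lora_b` inside `LoraWeights.load`.

assert get_scaling_factor(16, 8) == 2.0
assert abs(get_scaling_factor(16, 8, uses_rslora=True) - 16 / 8 ** 0.5) < 1e-9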
|
|
@ -0,0 +1,159 @@
|
|||
|
||||
from abc import ABC, abstractclassmethod
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set, Type
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
LORA = "lora"
|
||||
LM_HEAD = "lm_head"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdapterBatchMetadata:
|
||||
# [batch_size]
|
||||
adapter_indices: torch.Tensor
|
||||
|
||||
# [num_adapters]
|
||||
adapter_set: Set[int]
|
||||
|
||||
# [num_segments + 1]
|
||||
adapter_segments: torch.Tensor
|
||||
|
||||
# [num_segments]
|
||||
# maps from segment index to adapter index, i.e.:
|
||||
# segment_indices[s] == adapter_indices[i]
|
||||
segment_indices: List[int]
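# Illustrative example (values made up): a batch of six token rows where the
# first two rows use adapter 0, the next three use adapter 1, and the last row
# uses adapter 0 again would be described as
#   adapter_indices  = [0, 0, 1, 1, 1, 0]
#   adapter_set      = {0, 1}
#   adapter_segments = [0, 2, 5, 6]   # boundaries of runs of equal indices
#   segment_indices  = [0, 1, 0]      # adapter used by each segment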
|
||||
|
||||
|
||||
class AdapterWeights(ABC):
|
||||
@abstractclassmethod
|
||||
def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
|
||||
pass
|
||||
|
||||
@property
|
||||
def speculative_tokens(self) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
class BatchAdapterWeights(ABC):
|
||||
@abstractclassmethod
|
||||
def has_adapter(self, adapter_index: int) -> bool:
|
||||
pass
|
||||
|
||||
@abstractclassmethod
|
||||
def key(cls) -> str:
|
||||
pass
|
||||
|
||||
@abstractclassmethod
|
||||
def load(
|
||||
cls,
|
||||
adapter_weights: Dict[int, AdapterWeights],
|
||||
meta: "AdapterBatchMetadata",
|
||||
prefill: bool,
|
||||
prefill_head_indices: torch.Tensor,
|
||||
) -> Optional["BatchAdapterWeights"]:
|
||||
pass
|
||||
|
||||
|
||||
class LayerAdapterWeights:
|
||||
"""Adapter weights that apply to a particular layer."""
|
||||
|
||||
def __init__(self):
|
||||
self.adapter_weights: Dict[int, AdapterWeights] = {}
|
||||
|
||||
def add_adapter(self, adapter_idx: int, weights: AdapterWeights):
|
||||
self.adapter_weights[adapter_idx] = weights
|
||||
|
||||
def remove_adapter(self, adapter_idx: int):
|
||||
if adapter_idx not in self.adapter_weights:
|
||||
return
|
||||
del self.adapter_weights[adapter_idx]
|
||||
|
||||
@property
|
||||
def max_speculative_tokens(self) -> int:
|
||||
return max(
|
||||
adapter_weights.speculative_tokens
|
||||
for adapter_weights in self.adapter_weights.values()
|
||||
)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return len(self.adapter_weights) == 0
|
||||
|
||||
def get_data(
|
||||
self,
|
||||
meta: AdapterBatchMetadata,
|
||||
prefill: bool,
|
||||
prefill_head_indices: Optional[torch.Tensor],
|
||||
) -> Dict[str, BatchAdapterWeights]:
|
||||
# bucket adapters by batch class
|
||||
adapter_batch_types: Dict[
|
||||
Type[BatchAdapterWeights], Dict[int, AdapterWeights]
|
||||
] = defaultdict(dict)
|
||||
for adapter_index, adapter_weights in self.adapter_weights.items():
|
||||
for batch_type in adapter_weights.get_batch_types():
|
||||
adapter_batch_types[batch_type][adapter_index] = adapter_weights
|
||||
|
||||
batch_data = {}
|
||||
for batch_type, adapter_weights in adapter_batch_types.items():
|
||||
batched_weights = batch_type.load(
|
||||
adapter_weights, meta, prefill, prefill_head_indices
|
||||
)
|
||||
if batched_weights is not None:
|
||||
batch_data[batch_type.key()] = batched_weights
|
||||
return batch_data
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdapterBatchData:
|
||||
meta: AdapterBatchMetadata
|
||||
|
||||
# layer type -> adapter type -> batch weight data
|
||||
data: Dict[str, Dict[str, BatchAdapterWeights]]
|
||||
|
||||
prefill: bool
|
||||
|
||||
@staticmethod
|
||||
def from_meta(
|
||||
meta: AdapterBatchMetadata,
|
||||
weights: Dict[str, LayerAdapterWeights],
|
||||
prefill: bool,
|
||||
prefill_head_indices: Optional[torch.Tensor],
|
||||
) -> "AdapterBatchData":
|
||||
data = {}
|
||||
for k, v in weights.items():
|
||||
if v.is_empty():
|
||||
continue
|
||||
data[k] = v.get_data(
|
||||
meta, prefill, prefill_head_indices if k == LM_HEAD else None
|
||||
)
|
||||
return AdapterBatchData(meta=meta, data=data, prefill=prefill)
|
||||
|
||||
def ranks(self) -> Set[int]:
|
||||
# TODO(travis): refactor to be less coupled to lora implementation
|
||||
ranks = set()
|
||||
for layer_data in self.data.values():
|
||||
lora_data = layer_data.get(LORA)
|
||||
if lora_data is None:
|
||||
continue
|
||||
|
||||
for rank_data in lora_data.rank_data.values():
|
||||
ranks.add(rank_data.rank)
|
||||
|
||||
return ranks
|
||||
|
||||
def layer_names(self) -> Set[str]:
|
||||
return set(self.data.keys())
|
||||
|
||||
def adapter_keys(self) -> Set[str]:
|
||||
adapter_keys = set()
|
||||
for layer_data in self.data.values():
|
||||
adapter_keys.update(layer_data.keys())
|
||||
return adapter_keys
|
||||
|
||||
@property
|
||||
def max_rank(self) -> int:
|
||||
ranks = self.ranks()
|
||||
return max(ranks) if len(ranks) > 0 else 0
|
|
@@ -12,3 +12,9 @@ from text_generation_server.layers.speculative import SpeculativeHead
# Just to add the `load` methods.
from text_generation_server.layers.layernorm import load_layer_norm
from text_generation_server.layers.conv import load_conv2d

from text_generation_server.layers.lora import (
    LoraLinear,
    TensorParallelMultiAdapterLinear,
    TensorParallelAdapterRowLinear,
)

@ -0,0 +1,244 @@
|
|||
import math
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from accelerate import init_empty_weights
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from text_generation_server.utils.sgmv import (
|
||||
add_lora_a_bgmv,
|
||||
add_lora_b_bgmv,
|
||||
has_sgmv,
|
||||
lora_a_sgmv_cutlass,
|
||||
lora_b_sgmv_cutlass,
|
||||
orient_for_rank,
|
||||
)
|
||||
|
||||
LORA = "lora"
|
||||
MEDUSA = "medusa"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from text_generation_server.adapters import AdapterBatchData
|
||||
from text_generation_server.adapters.lora import BatchLoraWeights
|
||||
|
||||
|
||||
class LoraLinear(nn.Module):
|
||||
def __init__(self, base_layer, layer_id, process_group):
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
self.layer_id = layer_id
|
||||
self.process_group = process_group
|
||||
|
||||
def forward_layer_type(
|
||||
self,
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
adapter_data: "AdapterBatchData",
|
||||
layer_type: str,
|
||||
start_idx: int,
|
||||
end_idx: int,
|
||||
) -> torch.Tensor:
|
||||
data = adapter_data.data.get(layer_type)
|
||||
data: Optional["BatchLoraWeights"] = (
|
||||
data.get(LORA) if data is not None else None
|
||||
)
|
||||
|
||||
if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
|
||||
if end_idx - start_idx != result.shape[1]:
|
||||
proj = torch.zeros_like(result[:, start_idx:end_idx])
|
||||
else:
|
||||
proj = result
|
||||
|
||||
for r, rank_segments in data.rank_data.items():
|
||||
lora_a_ptr = rank_segments.lora_a_ptr
|
||||
lora_b_ptr = rank_segments.lora_b_ptr
|
||||
|
||||
if data.use_sgmv:
|
||||
# Use SGMV for prefill
|
||||
if lora_a_ptr is not None and lora_b_ptr is not None:
|
||||
v = lora_a_sgmv_cutlass(
|
||||
input,
|
||||
rank_segments.tmp_shrink,
|
||||
lora_a_ptr,
|
||||
rank_segments.segment_starts,
|
||||
rank_segments.segment_ends,
|
||||
self.layer_id,
|
||||
r,
|
||||
)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
v = self.collect_lora_a(v)
|
||||
|
||||
lora_b_sgmv_cutlass(
|
||||
proj,
|
||||
v,
|
||||
rank_segments.tmp_expand,
|
||||
lora_b_ptr,
|
||||
rank_segments.segment_starts,
|
||||
rank_segments.segment_ends,
|
||||
self.layer_id,
|
||||
)
|
||||
else:
|
||||
# Use BGMV for decode
|
||||
if lora_a_ptr is not None and lora_b_ptr is not None:
|
||||
v = torch.zeros(
|
||||
(input.size(0), r), dtype=input.dtype, device=input.device
|
||||
)
|
||||
add_lora_a_bgmv(
|
||||
v,
|
||||
input,
|
||||
lora_a_ptr,
|
||||
rank_segments.indices,
|
||||
self.layer_id,
|
||||
)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
v = self.collect_lora_a(v)
|
||||
|
||||
add_lora_b_bgmv(
|
||||
proj,
|
||||
v,
|
||||
lora_b_ptr,
|
||||
rank_segments.indices,
|
||||
self.layer_id,
|
||||
)
|
||||
|
||||
if end_idx - start_idx != result.shape[1]:
|
||||
result[:, start_idx:end_idx] += proj
|
||||
else:
|
||||
for adapter_index in adapter_data.meta.adapter_set:
|
||||
if data is not None and data.has_adapter(adapter_index):
|
||||
adapter_mask = (
|
||||
(adapter_data.meta.adapter_indices == adapter_index)
|
||||
.to(input.dtype)
|
||||
.view(-1, 1)
|
||||
)
|
||||
layer_result = self.forward_lora(
|
||||
input, data, adapter_index, adapter_mask
|
||||
)
|
||||
result[:, start_idx:end_idx] += layer_result
|
||||
|
||||
return result
|
||||
|
||||
def forward_lora(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
data: "BatchLoraWeights",
|
||||
adapter_index: int,
|
||||
adapter_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
|
||||
lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
|
||||
|
||||
lora_a = orient_for_rank(lora_a, lora_b.size(0))
|
||||
|
||||
a_out = input @ lora_a
|
||||
if self.process_group.size() > 1:
|
||||
a_out = self.collect_lora_a(a_out)
|
||||
|
||||
result = (a_out @ lora_b) * adapter_mask
|
||||
return result
|
||||
|
||||
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError("Implemented in subclasses")
|
||||
|
||||
|
||||
class TensorParallelMultiAdapterLinear(LoraLinear):
|
||||
def __init__(self, base_layer, layer_id, layer_names, sizes, process_group):
|
||||
super().__init__(base_layer, layer_id, process_group)
|
||||
self.layer_names = layer_names
|
||||
self.sizes = sizes
|
||||
|
||||
@classmethod
|
||||
def load(cls, base_layer, layer_id, layer_names, sizes, process_group):
|
||||
return TensorParallelMultiAdapterLinear(
|
||||
base_layer, layer_id, layer_names, sizes, process_group
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
|
||||
) -> torch.Tensor:
|
||||
result = self.base_layer(input)
|
||||
|
||||
# handle models like Bloom that have inputs of shape
|
||||
# (batch_size, sequence_length, hidden_size)
|
||||
# we need to reshape them to (batch_size * sequence_length, hidden_size)
|
||||
# for the LoRA computation, then reshape back
|
||||
prev_shape = result.shape
|
||||
is_3d = len(input.shape) >= 3
|
||||
if is_3d:
|
||||
input = input.reshape(-1, input.shape[-1])
|
||||
result = result.reshape(-1, result.shape[-1])
|
||||
|
||||
offset = 0
|
||||
for i, layer_name in enumerate(self.layer_names):
|
||||
start_idx = offset // self.process_group.size()
|
||||
|
||||
if self.sizes is not None:
|
||||
offset += self.sizes[i]
|
||||
end_idx = offset // self.process_group.size()
|
||||
else:
|
||||
end_idx = result.shape[1]
|
||||
|
||||
result = self.forward_layer_type(
|
||||
result, input, adapter_data, layer_name, start_idx, end_idx
|
||||
)
|
||||
|
||||
if is_3d:
|
||||
result = result.reshape(prev_shape)
|
||||
|
||||
return result
|
||||
|
||||
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
|
||||
# Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
|
||||
# We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
|
||||
#
|
||||
# TODO(travis): this is not very efficient as we do an all-gather for every adapter,
|
||||
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
|
||||
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
|
||||
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
|
||||
gathered_tensors = [
|
||||
torch.empty_like(a_out) for _ in range(self.process_group.size())
|
||||
]
|
||||
torch.distributed.all_gather(gathered_tensors, a_out)
|
||||
return torch.cat(gathered_tensors, dim=1)
|
||||
|
||||
|
||||
class TensorParallelAdapterRowLinear(LoraLinear):
|
||||
def __init__(self, base_layer, layer_id, layer_name, process_group):
|
||||
super().__init__(base_layer, layer_id, process_group)
|
||||
self.layer_name = layer_name
|
||||
|
||||
@classmethod
|
||||
def load(cls, base_layer, layer_id, layer_name, process_group):
|
||||
return cls(base_layer, layer_id, layer_name, process_group)
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
|
||||
) -> torch.Tensor:
|
||||
result = self.base_layer(input)
|
||||
|
||||
# Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
|
||||
stride = result.shape[-1] // self.process_group.size()
|
||||
start_idx = self.process_group.rank() * stride
|
||||
end_idx = (self.process_group.rank() + 1) * stride
|
||||
|
||||
self.forward_layer_type(
|
||||
result, input, adapter_data, self.layer_name, start_idx, end_idx
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
|
||||
# Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
|
||||
# We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
|
||||
#
|
||||
# TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
|
||||
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
|
||||
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
|
||||
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
|
||||
torch.distributed.all_reduce(a_out, group=self.process_group)
|
||||
return a_out
|
|
@ -38,6 +38,8 @@ from text_generation_server.layers import (
|
|||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
SpeculativeHead,
|
||||
TensorParallelMultiAdapterLinear,
|
||||
TensorParallelAdapterRowLinear,
|
||||
)
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
|
@ -50,6 +52,16 @@ if SYSTEM == "rocm":
|
|||
except Exception as e:
|
||||
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
||||
|
||||
# Constants
|
||||
Q_PROJ = "q_proj"
|
||||
K_PROJ = "k_proj"
|
||||
V_PROJ = "v_proj"
|
||||
O_PROJ = "o_proj"
|
||||
|
||||
GATE_PROJ = "gate_proj"
|
||||
UP_PROJ = "up_proj"
|
||||
DOWN_PROJ = "down_proj"
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
# Only defined in granite.
|
||||
|
@ -57,7 +69,7 @@ def load_attention(config, prefix, weights):
|
|||
|
||||
# if specific model type, load the correct attention
|
||||
if config.model_type == "phi3":
|
||||
return TensorParallelColumnLinear.load_qkv(
|
||||
base_layer = TensorParallelColumnLinear.load_qkv(
|
||||
config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
weights=weights,
|
||||
|
@ -66,7 +78,7 @@ def load_attention(config, prefix, weights):
|
|||
num_key_value_heads=config.num_key_value_heads,
|
||||
)
|
||||
elif config.model_type == "baichuan":
|
||||
return TensorParallelColumnLinear.load_qkv(
|
||||
base_layer = TensorParallelColumnLinear.load_qkv(
|
||||
config,
|
||||
prefix=f"{prefix}.W_pack",
|
||||
weights=weights,
|
||||
|
@ -76,7 +88,7 @@ def load_attention(config, prefix, weights):
|
|||
)
|
||||
|
||||
# otherwise, load the default attention based on the number of heads
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
base_layer = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
|
@ -84,6 +96,19 @@ def load_attention(config, prefix, weights):
|
|||
bias=bias,
|
||||
)
|
||||
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
return TensorParallelMultiAdapterLinear.load(
|
||||
base_layer,
|
||||
layer_id,
|
||||
[Q_PROJ, K_PROJ, V_PROJ],
|
||||
sizes=[
|
||||
head_size * config.num_attention_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
],
|
||||
process_group=weights.process_group,
|
||||
)
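# For example (illustrative, Llama-2-7B-like config): hidden_size=4096,
# num_attention_heads=32, num_key_value_heads=32 gives head_size=128 and
# sizes=[4096, 4096, 4096], i.e. the per-projection column widths the
# multi-adapter wrapper uses to slice the fused qkv output.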
|
||||
|
||||
|
||||
class FlashLlamaAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
|
@ -124,7 +149,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
self.query_key_value = load_attention(config, prefix, weights, index)
|
||||
self.index = index
|
||||
self.adapter_weights = {}
|
||||
adapter_names = list(lora_weights.keys())
|
||||
|
@ -161,12 +186,20 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
pre_multiplied_lora_matrix
|
||||
)
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||
o_proj,
|
||||
index,
|
||||
O_PROJ,
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
|
@ -185,8 +218,9 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
max_s,
|
||||
batch_lora_adapter_mask,
|
||||
lora_indices,
|
||||
adapter_data,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||
query, kv = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
|
@ -197,32 +231,6 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
|
||||
batch_size = query.size(0)
|
||||
|
||||
# hidden states without LoRA
|
||||
hs_wl = hidden_states[lora_indices == -1]
|
||||
|
||||
adapted_query_states = [hs_wl]
|
||||
adapted_value_states = [hs_wl]
|
||||
|
||||
for ind in range(self.n_loras):
|
||||
mask = lora_indices == ind
|
||||
hs_sub = hidden_states[mask]
|
||||
mat_q = torch.matmul(hs_sub, self.pre_multiplied_lora_matrices[ind, 0])
|
||||
mat_v = torch.matmul(hs_sub, self.pre_multiplied_lora_matrices[ind, 1])
|
||||
adapted_query_states.append(mat_q)
|
||||
adapted_value_states.append(mat_v)
|
||||
|
||||
query_adapted = torch.cat(adapted_query_states, dim=0).view(
|
||||
batch_size, self.num_heads, self.head_size
|
||||
)
|
||||
value_adapted = torch.cat(adapted_value_states, dim=0).view(
|
||||
batch_size, self.num_key_value_heads, self.head_size
|
||||
)
|
||||
|
||||
query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask]
|
||||
kv[batch_lora_adapter_mask, 1] += value_adapted[batch_lora_adapter_mask]
|
||||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||
|
@ -260,7 +268,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
def __init__(self, prefix, config, weights, index):
|
||||
super().__init__()
|
||||
self.hidden_act = config.hidden_act
|
||||
self.act = (
|
||||
|
@ -278,26 +286,46 @@ class LlamaMLP(nn.Module):
|
|||
# Fuse gate and up proj
|
||||
bias = getattr(config, "mlp_bias", False)
|
||||
if config.model_type == "phi3":
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
|
||||
gate_up_proj = TensorParallelColumnLinear.load_gate_up(
|
||||
config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
weights=weights,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=bias,
|
||||
)
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
|
||||
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||
gate_up_proj,
|
||||
index,
|
||||
[GATE_PROJ, UP_PROJ],
|
||||
sizes=[
|
||||
config.intermediate_size,
|
||||
config.intermediate_size,
|
||||
],
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||
down_proj,
|
||||
index,
|
||||
DOWN_PROJ,
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
|
@ -337,7 +365,9 @@ class FlashLlamaLayer(nn.Module):
|
|||
lora_weights=lora_weights,
|
||||
lora_configs=lora_configs,
|
||||
)
|
||||
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
self.mlp = LlamaMLP(
|
||||
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
|
||||
)
|
||||
|
||||
self.input_layernorm = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||
|
@ -362,6 +392,7 @@ class FlashLlamaLayer(nn.Module):
|
|||
max_s,
|
||||
batch_lora_adapter_mask,
|
||||
lora_indices,
|
||||
adapter_data,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
|
@ -378,6 +409,7 @@ class FlashLlamaLayer(nn.Module):
|
|||
max_s,
|
||||
batch_lora_adapter_mask,
|
||||
lora_indices,
|
||||
adapter_data,
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
|
@ -440,6 +472,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
batch_lora_adapter_mask: Optional[List[str]],
|
||||
lora_indices: Optional[torch.Tensor],
|
||||
adapter_data,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
|
@ -464,6 +497,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
max_s,
|
||||
batch_lora_adapter_mask,
|
||||
lora_indices,
|
||||
adapter_data,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
@ -512,6 +546,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
batch_lora_adapter_mask: Optional[List[str]] = None,
|
||||
lora_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.model(
|
||||
|
@ -527,6 +562,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
prefill_cache_indices=prefill_cache_indices,
|
||||
batch_lora_adapter_mask=batch_lora_adapter_mask,
|
||||
lora_indices=lora_indices,
|
||||
adapter_data=adapter_data,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
|
|
@ -13,6 +13,7 @@ from opentelemetry import trace
|
|||
from transformers import PreTrainedTokenizerBase
|
||||
from typing import Iterable, Optional, Tuple, List, Type, Dict
|
||||
|
||||
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
|
||||
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
||||
from text_generation_server.utils.chunks import concat_text_chunks
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
@ -31,6 +32,7 @@ from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
|
|||
import text_generation_server.models.globals as tgi_globals
|
||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||
from text_generation_server.utils.dist import MEMORY_FRACTION
|
||||
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
|
||||
|
||||
from text_generation_server.utils.import_utils import (
|
||||
empty_cache,
|
||||
|
@ -114,6 +116,9 @@ class FlashCausalLMBatch(Batch):
|
|||
top_n_tokens: List[int]
|
||||
top_n_tokens_tensor: torch.Tensor
|
||||
|
||||
# Adapter metadata for each request
|
||||
adapter_meta: AdapterBatchMetadata
|
||||
|
||||
# Number of blocks in this batch
|
||||
num_blocks: int
|
||||
# Maximum number of blocks
|
||||
|
@ -174,6 +179,9 @@ class FlashCausalLMBatch(Batch):
|
|||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
|
||||
adapter_indices_list = []
|
||||
adapter_set = set()
|
||||
|
||||
# Cumulative length
|
||||
cumulative_length = 0
|
||||
cumulative_max_length = 0
|
||||
|
@ -225,6 +233,9 @@ class FlashCausalLMBatch(Batch):
|
|||
stopping_criterias.append(stopping_criteria)
|
||||
top_n_tokens.append(r.top_n_tokens)
|
||||
|
||||
adapter_indices_list.append(torch.full((input_length,), r.adapter_index))
|
||||
adapter_set.add(r.adapter_index)
|
||||
|
||||
# Paged attention
|
||||
# Remove one as the first token does not have a past
|
||||
speculative_length = get_speculate()
|
||||
|
@ -296,6 +307,10 @@ class FlashCausalLMBatch(Batch):
|
|||
max_length, input_length + max_new_tokens + speculative_length
|
||||
)
|
||||
|
||||
adapter_indices = torch.cat(adapter_indices_list).to(
|
||||
dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||
next_token_chooser_parameters, dtype, device, tokenizer
|
||||
)
|
||||
|
@ -339,6 +354,11 @@ class FlashCausalLMBatch(Batch):
|
|||
input_lengths, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
|
||||
adapter_segments = torch.tensor(
|
||||
adapter_segments, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
if all_prefill_logprobs:
|
||||
prefill_head_indices = None
|
||||
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
|
||||
|
@ -393,6 +413,12 @@ class FlashCausalLMBatch(Batch):
|
|||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
num_blocks=num_blocks,
|
||||
max_blocks=max_blocks,
|
||||
adapter_meta=AdapterBatchMetadata(
|
||||
adapter_indices=adapter_indices,
|
||||
adapter_set=adapter_set,
|
||||
adapter_segments=adapter_segments,
|
||||
segment_indices=adapter_segment_indices,
|
||||
),
|
||||
speculative_ids=None,
|
||||
)
|
||||
|
||||
|
@ -443,6 +469,7 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
adapter_set = set()
|
||||
|
||||
num_blocks = 0
|
||||
max_blocks = 0
|
||||
|
@ -471,6 +498,8 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
top_n_tokens.append(self.top_n_tokens[idx])
|
||||
|
||||
adapter_set.add(self.requests[idx].adapter_index)
|
||||
|
||||
remaining_tokens = (
|
||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||
)
|
||||
|
@ -498,6 +527,7 @@ class FlashCausalLMBatch(Batch):
|
|||
# Index into tensors
|
||||
input_ids = self.input_ids[indices]
|
||||
position_ids = self.position_ids[indices]
|
||||
adapter_indices = self.adapter_meta.adapter_indices[indices]
|
||||
all_input_ids_tensor = self.all_input_ids_tensor[indices]
|
||||
block_tables_tensor = self.block_tables_tensor[indices]
|
||||
input_lengths_tensor = self.input_lengths_tensor[indices]
|
||||
|
@ -513,6 +543,11 @@ class FlashCausalLMBatch(Batch):
|
|||
# Move to GPU now that we have the whole tensor
|
||||
slot_indices = slot_indices.to(device)
|
||||
|
||||
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
|
||||
adapter_segments = torch.tensor(
|
||||
adapter_segments, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
return type(self)(
|
||||
batch_id=self.batch_id,
|
||||
requests=requests,
|
||||
|
@ -543,6 +578,12 @@ class FlashCausalLMBatch(Batch):
|
|||
num_blocks=num_blocks,
|
||||
max_blocks=max_blocks,
|
||||
speculative_ids=speculative_ids,
|
||||
adapter_meta=AdapterBatchMetadata(
|
||||
adapter_indices=adapter_indices,
|
||||
adapter_set=adapter_set,
|
||||
adapter_segments=adapter_segments,
|
||||
segment_indices=adapter_segment_indices,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -596,6 +637,14 @@ class FlashCausalLMBatch(Batch):
|
|||
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
|
||||
total_batch_size,
|
||||
)
|
||||
total_indices_size = sum(
|
||||
b.adapter_meta.adapter_indices.shape[0] for b in batches
|
||||
)
|
||||
adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
|
||||
total_indices_size
|
||||
)
|
||||
adapter_set = set()
|
||||
adapter_segment_builder = SegmentConcatBuilder()
|
||||
|
||||
start_slots = []
|
||||
block_tables = []
|
||||
|
@ -613,6 +662,7 @@ class FlashCausalLMBatch(Batch):
|
|||
# Cumulative length
|
||||
cumulative_batch_size = 0
|
||||
cumulative_slots = 0
|
||||
cumulative_adapter_indices_size = 0
|
||||
|
||||
for i, batch in enumerate(batches):
|
||||
requests.extend(batch.requests)
|
||||
|
@ -637,6 +687,18 @@ class FlashCausalLMBatch(Batch):
|
|||
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
|
||||
slots[slots_start_index:slots_end_index] = batch.slots
|
||||
|
||||
# Copy over adapter indices
|
||||
adapter_start_index = cumulative_adapter_indices_size
|
||||
adapter_end_index = (
|
||||
cumulative_adapter_indices_size
|
||||
+ batch.adapter_meta.adapter_indices.shape[0]
|
||||
)
|
||||
adapter_indices[adapter_start_index:adapter_end_index] = (
|
||||
batch.adapter_meta.adapter_indices
|
||||
)
|
||||
cumulative_adapter_indices_size = adapter_end_index
|
||||
adapter_set.update(batch.adapter_meta.adapter_set)
|
||||
|
||||
all_input_ids_tensor[
|
||||
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
|
||||
] = batch.all_input_ids_tensor[:, :max_length]
|
||||
|
@ -680,6 +742,8 @@ class FlashCausalLMBatch(Batch):
|
|||
else None
|
||||
)
|
||||
|
||||
_adapter_segments, _adapter_segment_indices = adapter_segment_builder.build()
|
||||
|
||||
return cls(
|
||||
batch_id=batches[0].batch_id,
|
||||
requests=requests,
|
||||
|
@ -719,6 +783,7 @@ class FlashCausalLMBatch(Batch):
|
|||
class FlashCausalLM(Model):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
model: torch.nn.Module,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_layers: int,
|
||||
|
@ -738,6 +803,7 @@ class FlashCausalLM(Model):
|
|||
self.kv_cache = []
|
||||
|
||||
super(FlashCausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=False,
|
||||
|
@ -996,7 +1062,7 @@ class FlashCausalLM(Model):
|
|||
)
|
||||
|
||||
def forward(
|
||||
self, batch: FlashCausalLMBatch
|
||||
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
# Model Forward
|
||||
if batch.speculative_ids is not None:
|
||||
|
@ -1066,13 +1132,6 @@ class FlashCausalLM(Model):
|
|||
batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
|
||||
lora_indices = torch.full((bs,), -1, dtype=torch.int32, device=self.device)
|
||||
|
||||
for i, r in enumerate(batch.requests):
|
||||
if r.adapter_id:
|
||||
lora_index = self.model.get_lora_index(r.adapter_id)
|
||||
input_length = batch.input_lengths[i]
|
||||
lora_indices[i : i + input_length] = lora_index
|
||||
batch_lora_adapter_mask[i] = True
|
||||
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
|
@ -1087,6 +1146,7 @@ class FlashCausalLM(Model):
|
|||
lm_head_indices=lm_head_indices,
|
||||
batch_lora_adapter_mask=batch_lora_adapter_mask,
|
||||
lora_indices=lora_indices,
|
||||
adapter_data=adapter_data,
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
|
@ -1123,7 +1183,34 @@ class FlashCausalLM(Model):
|
|||
prefill = batch.cu_seqlen_prefill is not None
|
||||
prefill_logprobs = batch.prefill_next_token_indices is not None
|
||||
|
||||
out, speculative_logits = self.forward(batch)
|
||||
# Update adapter indices for speculative tokens (if present)
|
||||
adapter_meta = batch.adapter_meta
|
||||
if batch.speculative_ids is not None:
|
||||
B, speculative_length = batch.speculative_ids.shape
|
||||
new_length = speculative_length + 1
|
||||
adapter_indices = (
|
||||
adapter_meta.adapter_indices.unsqueeze(-1)
|
||||
.expand(B, new_length)
|
||||
.reshape(-1)
|
||||
)
|
||||
adapter_segments = adapter_meta.adapter_segments * new_length
|
||||
adapter_meta = AdapterBatchMetadata(
|
||||
adapter_indices=adapter_indices,
|
||||
adapter_set=adapter_meta.adapter_set,
|
||||
adapter_segments=adapter_segments,
|
||||
segment_indices=adapter_meta.segment_indices,
|
||||
)
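# e.g. (illustrative) with adapter_indices=[3, 7] and speculative_length=1,
# each request now contributes new_length=2 rows, giving [3, 3, 7, 7], and the
# segment boundaries are scaled by the same factor.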
|
||||
|
||||
# Assign pointers to adapter weights
|
||||
# TODO(travis): don't update this if indices haven't changed
|
||||
adapter_data = AdapterBatchData.from_meta(
|
||||
adapter_meta,
|
||||
self.layer_to_adapter_weights,
|
||||
prefill,
|
||||
batch.prefill_head_indices,
|
||||
)
|
||||
|
||||
out, speculative_logits = self.forward(batch, adapter_data)
|
||||
|
||||
if prefill:
|
||||
next_token_logits = (
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch.distributed
|
|||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple, Dict, List
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||
|
@ -22,6 +22,30 @@ tracer = trace.get_tracer(__name__)
|
|||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.lora import LoraConfig
|
||||
|
||||
Q_PROJ = "q_proj"
|
||||
K_PROJ = "k_proj"
|
||||
V_PROJ = "v_proj"
|
||||
O_PROJ = "o_proj"
|
||||
|
||||
GATE_PROJ = "gate_proj"
|
||||
UP_PROJ = "up_proj"
|
||||
DOWN_PROJ = "down_proj"
|
||||
|
||||
LM_HEAD = "lm_head"
|
||||
|
||||
|
||||
# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
|
||||
ADAPTER_LAYERS = [
|
||||
Q_PROJ,
|
||||
K_PROJ,
|
||||
V_PROJ,
|
||||
O_PROJ,
|
||||
GATE_PROJ,
|
||||
UP_PROJ,
|
||||
DOWN_PROJ,
|
||||
] # LM_HEAD
|
||||
ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}
|
||||
|
||||
|
||||
class FlashLlama(FlashCausalLM):
|
||||
def __init__(
|
||||
|
@ -80,6 +104,7 @@ class FlashLlama(FlashCausalLM):
|
|||
)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashLlama, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
|
@ -90,3 +115,59 @@ class FlashLlama(FlashCausalLM):
|
|||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_adapter_loading(self) -> bool:
|
||||
return True
|
||||
|
||||
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
|
||||
layer_weights = {}
|
||||
|
||||
prefix = "model.layers"
|
||||
for i, layer in enumerate(self.model.model.layers):
|
||||
layer_weights[(i, Q_PROJ)] = (
|
||||
f"{prefix}.{i}.self_attn.q_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, K_PROJ)] = (
|
||||
f"{prefix}.{i}.self_attn.k_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, V_PROJ)] = (
|
||||
f"{prefix}.{i}.self_attn.v_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, O_PROJ)] = (
|
||||
f"{prefix}.{i}.self_attn.o_proj",
|
||||
layer.self_attn.o_proj,
|
||||
)
|
||||
|
||||
layer_weights[(i, GATE_PROJ)] = (
|
||||
f"{prefix}.{i}.mlp.gate_proj",
|
||||
layer.mlp.gate_up_proj,
|
||||
)
|
||||
layer_weights[(i, UP_PROJ)] = (
|
||||
f"{prefix}.{i}.mlp.up_proj",
|
||||
layer.mlp.gate_up_proj,
|
||||
)
|
||||
layer_weights[(i, DOWN_PROJ)] = (
|
||||
f"{prefix}.{i}.mlp.down_proj",
|
||||
layer.mlp.down_proj,
|
||||
)
|
||||
|
||||
layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
|
||||
return layer_weights
|
||||
|
||||
@property
|
||||
def adapter_layers(self) -> List[str]:
|
||||
return ADAPTER_LAYERS
|
||||
|
||||
@property
|
||||
def default_traced_adapter_layers(self) -> List[str]:
|
||||
return [Q_PROJ, V_PROJ]
|
||||
|
||||
def get_num_layers_for_type(self, layer_type: str) -> int:
|
||||
return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
|
||||
|
||||
def is_row_parallel(self, layer_type: str) -> bool:
|
||||
return layer_type in ROW_PARALLEL
|
||||
|
|
|
@ -2,12 +2,50 @@ import inspect
|
|||
import torch
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple, Optional, TypeVar, Type
|
||||
from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
|
||||
from collections import defaultdict
|
||||
from transformers import PreTrainedTokenizerBase, PretrainedConfig
|
||||
|
||||
from text_generation_server.models.types import Batch, Generation
|
||||
from text_generation_server.utils.speculate import get_speculate
|
||||
from text_generation_server.pb.generate_pb2 import InfoResponse
|
||||
from text_generation_server.adapters.weights import LayerAdapterWeights
|
||||
from text_generation_server.utils.adapter import (
|
||||
load_and_merge_adapters,
|
||||
AdapterParameters,
|
||||
AdapterSource,
|
||||
)
|
||||
from loguru import logger
|
||||
|
||||
|
||||
BASE_MODEL_ADAPTER_ID = "__base_model__"
|
||||
|
||||
|
||||
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
|
||||
block_size = size // world_size
|
||||
start = offset + rank * block_size
|
||||
stop = offset + (rank + 1) * block_size
|
||||
return start, stop
|
||||
|
||||
|
||||
def shard_on_dim(
|
||||
t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
|
||||
):
|
||||
world_size = process_group.size()
|
||||
rank = process_group.rank()
|
||||
|
||||
size = t.shape[dim]
|
||||
start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
|
||||
|
||||
if dim == 0:
|
||||
tensor = t[start:stop]
|
||||
elif dim == 1:
|
||||
tensor = t[:, start:stop]
|
||||
else:
|
||||
raise NotImplementedError("Let's make that generic when needed")
|
||||
|
||||
return tensor
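# Example (illustrative): with world_size=2, shard_on_dim on an (8, 16) tensor
# along dim=0 gives rows 0-3 to rank 0 and rows 4-7 to rank 1; along dim=1
# each rank instead keeps 8 of the 16 columns.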
|
||||
|
||||
|
||||
B = TypeVar("B", bound=Batch)


@@ -15,6 +53,7 @@ B = TypeVar("B", bound=Batch)
class Model(ABC):
    def __init__(
        self,
        model_id: str,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        requires_padding: bool,
@@ -25,6 +64,7 @@ class Model(ABC):
        sliding_window: Optional[int] = None,
        speculate: Optional[int] = None,
    ):
        self.model_id = model_id
        self.model = model.eval()
        self.tokenizer = tokenizer

@@ -42,6 +82,12 @@ class Model(ABC):
        self.world_size = world_size
        self.sliding_window = sliding_window if sliding_window != -1 else None

        self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
            LayerAdapterWeights
        )
        self.target_to_layer = self.adapter_target_to_layer()
        self.loaded_adapters = set()

        if speculate is None:
            speculate = get_speculate()
        self.speculate = speculate
@@ -119,3 +165,156 @@ class Model(ABC):
            raise RuntimeError(
                f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
            )

    @property
    def supports_adapter_loading(self) -> bool:
        return False

    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
        return {}

    @property
    def adapter_layers(self) -> List[str]:
        return []

    @property
    def default_traced_adapter_layers(self) -> List[str]:
        return []

    def get_num_layers_for_type(self, layer_type: str) -> int:
        return 0

    def is_row_parallel(self, layer_type: str) -> bool:
        return False

    @property
    def max_speculative_tokens(self) -> int:
        return max(
            [
                weights.max_speculative_tokens
                for weights in self.layer_to_adapter_weights.values()
            ],
            default=0,
        )

    def load_adapter(
        self,
        adapter_parameters: AdapterParameters,
        adapter_source: AdapterSource,
        adapter_index: int,
        api_token: str,
        dynamic: bool = True,
    ):
        """Loads adapter weights from disk / host memory on the GPU.

        adapter_id must be `BASE_MODEL_ADAPTER_ID` if the adapter is statically loaded
        into the model. Otherwise, the adapter weights are applied during the forward
        pass and stored separately from the base model parameters.
        """
        if adapter_index in self.loaded_adapters:
            # Adapter already loaded
            return

        if not self.supports_adapter_loading:
            raise ValueError("This model does not support adapter loading.")

        if dynamic and not self.dynamic_adapter_loading_enabled:
            raise ValueError(
                f"This model was initialized with the adapter {self.static_adapter_id} "
                f"and therefore does not support dynamic adapter loading. "
                f"Please initialize a new model instance from the base model in "
                f"order to use the dynamic adapter loading feature."
            )

        logger.info(
            f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
        )
        weight_names = tuple([v[0] for v in self.target_to_layer.values()])
        (
            module_map,
            adapter_config,
            adapter_weight_names,
            adapter_tokenizer,
        ) = load_and_merge_adapters(
            self.model_id,
            adapter_parameters,
            adapter_source,
            adapter_index,
            weight_names,
            api_token,
            False,
        )

        unused_weight_names = adapter_weight_names.copy()
        for layer_name in self.adapter_layers:
            adapter_weights = adapter_config.load_batched_adapter_weights(
                self,
                module_map,
                layer_name,
                unused_weight_names,
                dynamic,
            )

            if adapter_weights is None:
                continue

            layer_weights = self.layer_to_adapter_weights[layer_name]
            layer_weights.add_adapter(adapter_index, adapter_weights)

        if len(unused_weight_names) > 0:
            logger.warning(
                f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
            )

        if adapter_tokenizer is not None:
            self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

        self.loaded_adapters.add(adapter_index)

    def shard_lora_weights(
        self,
        weights_a: List[torch.Tensor],
        weights_b: List[torch.Tensor],
        layer_type: str,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        # [hidden_size, r]
        split_dim = 0 if self.is_row_parallel(layer_type) else 1
        weights_a = [
            shard_on_dim(w, dim=split_dim, process_group=self.process_group)
            for w in weights_a
        ]

        # [r, hidden_size]
        weights_b = [
            shard_on_dim(w, dim=1, process_group=self.process_group) for w in weights_b
        ]

        return weights_a, weights_b

    def offload_adapter(
        self,
        adapter_parameters: AdapterParameters,
        adapter_source: AdapterSource,
        adapter_index: int,
    ):
        """Offloads the adapter weights from GPU to CPU or disk."""
        if adapter_index not in self.loaded_adapters:
            # Adapter already offloaded
            return

        if not self.supports_adapter_loading:
            raise ValueError("This model does not support adapter loading.")

        if not self.dynamic_adapter_loading_enabled:
            raise ValueError(
                f"This model was initialized with the adapter {self.static_adapter_id} "
                f"and therefore does not support dynamic adapter loading. "
                f"Please initialize a new model instance from the base model in "
                f"order to use the dynamic adapter loading feature."
            )

        for layer_name in self.adapter_layers:
            if layer_name in self.layer_to_adapter_weights:
                self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index)

        self.loaded_adapters.remove(adapter_index)
@@ -30,6 +30,9 @@ except (ImportError, NotImplementedError):
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_model_id
from text_generation_server.utils.adapter import (
    AdapterParameters,
)


class SignalHandler:
@@ -235,6 +238,30 @@ def serve(
                trust_remote_code,
                max_input_tokens,
            )

            # TODO: avoid hacky hardcoded adapter id
            adapter_parameters = AdapterParameters(
                adapter_ids=lora_adapter_ids,
                weights=[
                    # TODO: fill with actual weights
                    torch.tensor([1.0], dtype=torch.float32)
                ],
                merge_strategy=0,
                density=0.0,
                majority_sign_method=0,
            )
            adapter_source = None
            adapter_index = None
            api_token = None

            model.load_adapter(
                adapter_parameters,
                adapter_source,
                adapter_index,
                api_token,
                False,
            )

        except Exception:
            logger.exception("Error when initializing model")
            raise
@@ -0,0 +1,196 @@
import warnings
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Set, Tuple

from safetensors.torch import load_file
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

from text_generation_server.pb import generate_pb2
from text_generation_server.utils.merges.strategies import merge_adapters

from text_generation_server.utils import hub
from text_generation_server.adapters.lora import LoraConfig


if TYPE_CHECKING:
    from text_generation_server.adapters.config import AdapterConfig, ModuleMap


BASE_MODEL_ADAPTER_ID = "__base_model__"


class AdapterParameters:
    def __init__(
        self, adapter_ids, weights, merge_strategy, density, majority_sign_method
    ):
        self.adapter_ids = adapter_ids
        self.weights = weights
        self.merge_strategy = merge_strategy
        self.density = density
        self.majority_sign_method = majority_sign_method


class AdapterSource:
    def __init__(self, adapter_id: str, model_id: str, revision: str):
        self.adapter_id = adapter_id
        self.model_id = model_id
        self.revision = revision


def load_and_merge_adapters(
    model_id: str,
    adapter_parameters: AdapterParameters,
    adapter_source: str,
    adapter_index: int,
    weight_names: Tuple[str],
    api_token: str,
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
    if len(adapter_parameters.adapter_ids) == 1:
        return load_module_map(
            model_id,
            adapter_parameters.adapter_ids[0],
            adapter_source,
            weight_names,
            api_token,
            trust_remote_code,
        )

    adapter_params = AdapterParametersContainer(
        adapter_parameters, adapter_source, adapter_index
    )
    return _load_and_merge(
        model_id, adapter_params, weight_names, api_token, trust_remote_code
    )


class AdapterParametersContainer:
    def __init__(self, adapter_parameters, adapter_source, adapter_index):
        self.adapter_parameters = adapter_parameters
        self.adapter_source = adapter_source
        self.adapter_index = adapter_index

    def __hash__(self) -> int:
        return self.adapter_index


@lru_cache(maxsize=32)
def _load_and_merge(
    model_id: str,
    adapter_params: AdapterParametersContainer,
    weight_names: Tuple[str],
    api_token: str,
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
    params = adapter_params.adapter_parameters

    adapters_to_merge = []
    merged_weight_names = set()
    tokenizer = None
    for adapter_id in params.adapter_ids:
        if adapter_id == BASE_MODEL_ADAPTER_ID:
            raise ValueError("Base model adapter cannot be merged.")

        module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
            load_module_map(
                model_id,
                adapter_id,
                adapter_params.adapter_source,
                weight_names,
                api_token,
                trust_remote_code,
            )
        )

        adapters_to_merge.append((module_map, adapter_config))
        merged_weight_names = merged_weight_names.union(adapter_weight_names)
        if tokenizer is None:
            tokenizer = adapter_tokenizer

    if len(adapters_to_merge) == 0:
        raise ValueError("No adapters to merge.")

    module_map, adapter_config = merge_adapters(adapters_to_merge, params)
    return module_map, adapter_config, merged_weight_names, tokenizer


def check_architectures(
    model_id: str,
    adapter_id: str,
    adapter_config: "AdapterConfig",
    trust_remote_code: bool = False,
):
    try:
        if not adapter_config.base_model_name_or_path:
            # Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None)
            return

        expected_config = AutoConfig.from_pretrained(
            model_id, trust_remote_code=trust_remote_code
        )
        model_config = AutoConfig.from_pretrained(
            adapter_config.base_model_name_or_path, trust_remote_code=trust_remote_code
        )
    except Exception as e:
        warnings.warn(
            f"Unable to check architecture compatibility for adapter '{adapter_id}' "
            f"against model '{model_id}'. Assuming they are compatible. Error: {e}"
        )
        return

    if model_config.architectures == expected_config.architectures:
        warnings.warn(
            f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. "
            f"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead."
        )
    else:
        # TODO(travis): revisit this when we support classification heads which will not use CausalLM
        raise ValueError(
            f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
            f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. "
            f"Use --model-id '{adapter_config.base_model_name_or_path}' instead."
        )


@lru_cache(maxsize=128)
def load_module_map(
    model_id: str,
    adapter_id: str,
    adapter_source: str,
    weight_names: Tuple[str],
    api_token: str,
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
    print("adapter_id", adapter_id)

    revision = "main"

    adapter_config = LoraConfig.load(adapter_id, api_token)
    if adapter_config.base_model_name_or_path != model_id:
        check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)

    adapter_filenames = hub._cached_adapter_weight_files(
        adapter_id, revision=revision, extension=".safetensors"
    )

    try:
        adapter_tokenizer = AutoTokenizer.from_pretrained(
            adapter_config.config_path,
            token=api_token,
            trust_remote_code=trust_remote_code,
        )
    except Exception:
        # Adapter does not have a tokenizer, so fall back to the base model tokenizer
        adapter_tokenizer = None

    # load adapter weights from all shards (should have relatively small memory footprint)
    adapter_weights = {}
    for filename in adapter_filenames:
        adapter_weights.update(load_file(filename))

    # map the model weights to the relevant adapter weights (LoRA A and B matrices)
    module_map, adapter_weight_names = adapter_config.map_weights_for_model(
        adapter_weights, weight_names
    )
    return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
@@ -32,6 +32,7 @@ class LoraConfig:
        task_type="CAUSAL_LM",
        use_dora=False,
        use_rslora=False,
        config_path=None,
    ):
        self.alpha_pattern = alpha_pattern or {}
        self.auto_mapping = auto_mapping
@@ -57,12 +58,13 @@ class LoraConfig:
        self.task_type = task_type
        self.use_dora = use_dora
        self.use_rslora = use_rslora
        self.config_path = config_path

    @classmethod
    def from_file(cls, filename):
        with open(filename, "r") as f:
            json_data = json.load(f)
            return cls(**json_data)
            return cls(**json_data, config_path=filename)

    # TODO: support fetching the model from the hub if it's not in the cache
    @classmethod
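For reference, a sketch of the kind of `adapter_config.json` that `LoraConfig.from_file` is expected to parse, shown as a Python dict. Only the keys referenced elsewhere in this diff (`r`, `target_modules`, `base_model_name_or_path`, `task_type`, `use_dora`, `use_rslora`) are grounded in the code; the values, and the base model id, are illustrative assumptions.

# Hypothetical adapter_config.json contents (illustration only, not part of the commit)
example_adapter_config = {
    "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",  # assumed base model id
    "r": 16,
    "target_modules": ["q_proj", "v_proj"],
    "task_type": "CAUSAL_LM",
    "use_dora": False,
    "use_rslora": False,
}
# LoraConfig.from_file builds a LoraConfig from such a file and, after this change,
# also records the path it was loaded from in `config_path`.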
@@ -0,0 +1,223 @@
import copy
from abc import ABC
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union

import torch


class AdapterParameters:
    def __init__(
        self, adapter_ids, weights, merge_strategy, density, majority_sign_method
    ):
        self.adapter_ids = adapter_ids
        self.weights = weights
        self.merge_strategy = merge_strategy
        self.density = density
        self.majority_sign_method = majority_sign_method


from text_generation_server.utils.merges.utils import (
    calculate_majority_sign_mask,
    disjoint_merge,
    prune,
)

if TYPE_CHECKING:
    from text_generation_server.adapters.lora import LoraConfig
    from text_generation_server.utils.adapter import ModuleMap


def _apply_weights(
    tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor
) -> torch.Tensor:
    if isinstance(tensors, torch.Tensor):
        t = tensors
    else:
        t = torch.stack(tensors, dim=0)

    # element-wise weighting of each task tensor
    # need to unsqueeze weights to match task tensor dimensions
    # for multiplication to apply element-wise
    while len(t.shape) > len(w.shape):
        w = w.unsqueeze(-1)
    return t * w
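A quick illustration of the broadcasting in `_apply_weights` (a sketch, not part of the commit): three task tensors of shape `[2, 2]` are stacked into `[3, 2, 2]`, and a weight vector of shape `[3]` is unsqueezed to `[3, 1, 1]` so each task tensor is scaled by its own weight.

import torch

tensors = [torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2)]
weights = torch.tensor([0.5, 1.0, 2.0])

t = torch.stack(tensors, dim=0)      # shape [3, 2, 2]
w = weights
while len(t.shape) > len(w.shape):   # unsqueeze weights to [3, 1, 1]
    w = w.unsqueeze(-1)

weighted = t * w                     # per-adapter scaling
print(weighted.sum(dim=0))           # linear merge: 0.5*1 + 1.0*2 + 2.0*3 = 8.5 everywhere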
class MergeStrategy(ABC):
    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        raise NotImplementedError()


class LinearMerge(MergeStrategy):
    def __init__(self, **kwargs):
        pass

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        weighted_task_tensors = _apply_weights(task_tensors, weights)
        return weighted_task_tensors.sum(dim=0)


class TiesMerge(MergeStrategy):
    def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
        self.density = density
        self.majority_sign_method = majority_sign_method

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        # sparsify
        task_tensors = [
            prune(tensor, self.density, method="magnitude") for tensor in task_tensors
        ]
        task_tensors = torch.stack(task_tensors, dim=0)

        # elect sign before applying weights
        majority_sign_mask = calculate_majority_sign_mask(
            task_tensors, method=self.majority_sign_method
        )
        weighted_task_tensors = _apply_weights(task_tensors, weights)

        # disjoint merge
        return disjoint_merge(weighted_task_tensors, majority_sign_mask)


class DareLinearMerge(MergeStrategy):
    def __init__(self, density: float, **kwargs):
        self.density = density

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        # sparsify
        task_tensors = [
            prune(tensor, self.density, method="random", rescale=True)
            for tensor in task_tensors
        ]
        weighted_task_tensors = _apply_weights(task_tensors, weights)
        return weighted_task_tensors.sum(dim=0)


class DareTiesMerge(MergeStrategy):
    def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
        self.density = density
        self.majority_sign_method = majority_sign_method

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        # sparsify
        task_tensors = [
            prune(tensor, self.density, method="random", rescale=True)
            for tensor in task_tensors
        ]
        task_tensors = torch.stack(task_tensors, dim=0)

        # elect sign before applying weights
        majority_sign_mask = calculate_majority_sign_mask(
            task_tensors, method=self.majority_sign_method
        )
        weighted_task_tensors = _apply_weights(task_tensors, weights)

        # disjoint merge
        mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask)
        return mixed_task_tensors


strategy_registry: Dict[str, Type[MergeStrategy]] = {
    "linear": LinearMerge,
    "ties": TiesMerge,
    "dare_linear": DareLinearMerge,
    "dare_ties": DareTiesMerge,
}


def merge_adapters(
    adapters: List[Tuple["ModuleMap", "LoraConfig"]],
    merge_params: AdapterParameters,
) -> Tuple["ModuleMap", "LoraConfig"]:
    # strategy_name = MergeStrategyEnum.Name(merge_params.merge_strategy).lower()
    strategy_name = "linear"

    weights = merge_params.weights
    if not weights:
        weights = torch.ones(len(adapters))
    else:
        weights = torch.tensor(weights)

    merge_config = {
        "density": merge_params.density,
        # "majority_sign_method": MajoritySignMethodEnum.Name(
        #     merge_params.majority_sign_method
        # ).lower(),
        "majority_sign_method": "total",
    }
    merge_strategy = strategy_registry[strategy_name](**merge_config)

    module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict(
        lambda: defaultdict(lambda: defaultdict(list))
    )
    lora_configs = []
    weight_name_to_adapter_idx = defaultdict(list)

    # input is list of (module_map, lora_config) tuples
    # convert into dict[k][param_name] -> list of tensors
    for idx, (module_map, lora_config) in enumerate(adapters):
        for weight_name, data in module_map.items():
            weight_name_to_adapter_idx[weight_name].append(idx)
            for k, (param_data, param_name) in data.items():
                module_maps[weight_name][k][param_name].append(param_data)
        lora_configs.append(lora_config)

    # validate lora configs are compatible
    _validate_lora_configs(lora_configs)

    # merge tensors for each module such that we have a single ModuleMap:
    # dict[k] -> merged tensor
    merged_module_map: "ModuleMap" = defaultdict(dict)
    for weight_name, data in module_maps.items():
        indices = weight_name_to_adapter_idx[weight_name]
        param_weights = weights[indices]
        for k, param_data in data.items():
            for param_name, tensors in param_data.items():
                merged_tensor = merge_strategy.merge(tensors, param_weights)
                merged_module_map[weight_name][k] = (merged_tensor, param_name)

    # merge lora configs
    merged_lora_config = _merge_lora_configs(lora_configs)

    return merged_module_map, merged_lora_config


def _validate_lora_configs(lora_configs: List["LoraConfig"]):
    # check that all configs have the same rank
    ranks = set(lora_config.r for lora_config in lora_configs)
    if len(ranks) > 1:
        raise ValueError(
            f"unable to merge adapters, lora configs have different ranks: {ranks}"
        )

    if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs):
        raise ValueError(
            "unable to merge adapters, lora configs have no target modules"
        )


def _merge_lora_configs(lora_configs: List["LoraConfig"]) -> "LoraConfig":
    merged_lora_config = copy.copy(lora_configs[0])

    # merge target modules as a union operation
    merged_target_modules = sorted(
        set(
            module
            for lora_config in lora_configs
            for module in lora_config.target_modules
        )
    )
    merged_lora_config.target_modules = merged_target_modules

    return merged_lora_config
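A small usage sketch of the strategy classes above (illustration only, not part of the commit): merging two task tensors with `LinearMerge` and `TiesMerge`, imported from the module path used elsewhere in this diff.

import torch

from text_generation_server.utils.merges.strategies import LinearMerge, TiesMerge

task_tensors = [
    torch.tensor([[0.4, -0.2], [0.0, 1.0]]),
    torch.tensor([[0.1, 0.3], [0.0, -1.0]]),
]
weights = torch.tensor([1.0, 1.0])

# linear: weighted sum of the task tensors
linear_merged = LinearMerge().merge(task_tensors, weights)

# ties: prune small values, elect a majority sign per element, then merge the agreeing values
ties_merged = TiesMerge(density=0.5, majority_sign_method="total").merge(
    task_tensors, weights
)
print(linear_merged, ties_merged, sep="\n")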
@@ -0,0 +1,108 @@
# coding=utf-8
# From: https://github.com/huggingface/peft/pull/1364
# Copyright 2024-present the HuggingFace Inc. team.
# Modifications by Predibase, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Literal

import torch


def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor:
    """
    Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction
    `density`.

    Args:
        tensor (`torch.Tensor`):The tensor to prune.
        density (`float`):The fraction of values to preserve. Should be in [0,1].
    """
    mask = torch.zeros_like(tensor).reshape(-1)
    k = int(density * tensor.reshape(-1).shape[0])
    top_k = torch.topk(tensor.abs().reshape(-1), k=k, largest=True)
    mask[top_k[1]] = 1
    return tensor * mask.reshape(tensor.shape)


def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:
    """
    Prune random values of the task tensors, preserving the specified fraction
    `density`.

    Args:
        tensor (`torch.Tensor`):The tensor to prune.
        density (`float`):The fraction of values to preserve. Should be in [0,1].
        rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor.
    """
    mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density))
    pruned_tensor = tensor * mask
    if rescale:
        # assign the result, otherwise the rescale has no effect
        pruned_tensor = torch.div(input=pruned_tensor, other=density)
    return pruned_tensor


def prune(
    tensor: torch.Tensor,
    density: float,
    method: Literal["magnitude", "random"],
    rescale: bool = False,
) -> torch.Tensor:
    """
    Prune the values of task tensors based on the `method`.

    Args:
        tensor (`torch.Tensor`):The tensor to prune.
        density (`float`):The fraction of values to preserve. Should be in [0,1].
        method (`str`):The method to use to prune. Should be one of ["magnitude", "random"].
        rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor.
    """
    if density >= 1:
        return tensor
    elif density < 0:
        raise ValueError(f"Density should be >= 0, got {density}")
    if method == "magnitude":
        return magnitude_based_pruning(tensor, density)
    elif method == "random":
        return random_pruning(tensor, density, rescale=rescale)
    else:
        raise ValueError(f"Unknown method {method}")


def calculate_majority_sign_mask(
    tensor: torch.Tensor, method: Literal["total", "frequency"] = "total"
):
    """
    Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0.

    Args:
        tensor (`torch.Tensor`):The tensor to get the mask from.
        method (`str`):The method to use to get the mask. Should be one of ["total", "frequency"].
    """

    sign = tensor.sign()
    if method == "total":
        sign_magnitude = (sign * tensor.abs()).sum(dim=0)
    elif method == "frequency":
        sign_magnitude = sign.sum(dim=0)
    else:
        raise RuntimeError(f'Unimplemented mask method "{method}"')
    majority_sign = torch.where(sign_magnitude >= 0, 1, -1)
    return sign == majority_sign


def disjoint_merge(task_tensors, majority_sign_mask):
    mixed_task_tensors = (task_tensors * majority_sign_mask).sum(dim=0)
    num_params_preserved = majority_sign_mask.sum(dim=0)
    return mixed_task_tensors / torch.clamp(num_params_preserved, min=1.0)
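To make the pipeline above concrete, a small sketch (not part of the commit) that chains `prune`, `calculate_majority_sign_mask`, and `disjoint_merge` by hand, which is essentially what `TiesMerge.merge` does with unit weights.

import torch

from text_generation_server.utils.merges.utils import (
    calculate_majority_sign_mask,
    disjoint_merge,
    prune,
)

task_tensors = [
    torch.tensor([0.9, -0.1, 0.4, 0.0]),
    torch.tensor([-0.8, 0.2, 0.5, 0.0]),
]

# keep only the largest 50% of values in each task tensor
pruned = torch.stack([prune(t, density=0.5, method="magnitude") for t in task_tensors])

# per-element majority sign across tasks, then average only the agreeing values
mask = calculate_majority_sign_mask(pruned, method="total")
merged = disjoint_merge(pruned, mask)
print(merged)  # tensor([0.9000, 0.0000, 0.4500, 0.0000])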
@@ -0,0 +1,62 @@
from typing import List, Tuple, Union

import torch


def find_segments(
    adapter_indices: Union[torch.Tensor, List[int]]
) -> Tuple[List[int], List[int]]:
    segments = [0]
    segment_indices = []

    if isinstance(adapter_indices, torch.Tensor):
        # Calling .item() repeatedly on CUDA tensor is very slow, so we move it to CPU first
        adapter_indices = adapter_indices.cpu().tolist()

    start_index = 0
    for i in range(1, len(adapter_indices)):
        if adapter_indices[i] != adapter_indices[i - 1]:
            segments.append(i)
            segment_indices.append(adapter_indices[i - 1])
            start_index = i

    # Handle the last segment
    if start_index < len(adapter_indices):
        segments.append(len(adapter_indices))
        segment_indices.append(adapter_indices[-1])

    return segments, segment_indices
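A quick illustration of `find_segments` (a sketch, not part of the commit, assuming the function above is in scope): contiguous runs of the same adapter index become segments, expressed as boundary offsets plus the adapter index for each run.

# requests in a batch, labelled by the adapter index each one uses
adapter_indices = [0, 0, 1, 1, 2]

segments, segment_indices = find_segments(adapter_indices)
print(segments)         # [0, 2, 4, 5]  -> boundaries of the three runs
print(segment_indices)  # [0, 1, 2]     -> adapter index for each run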
class SegmentConcatBuilder:
    def __init__(self):
        self.adapter_segment_indices = []
        self.adapter_segment_tensors = []

    def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]):
        # Update adapter segments
        if self.adapter_segment_tensors:
            # Because we have already processed at least one batch, remove the 0 start index
            # from this batch denoting the beginning of the segment, then offset all segment
            # positions by the value of the last segment in the previous batch to account for
            # the concatenation.
            adapter_segments = (
                adapter_segments[1:] + self.adapter_segment_tensors[-1][-1]
            )

        if (
            self.adapter_segment_indices
            and self.adapter_segment_indices[-1] == segment_indices[0]
        ):
            # If the last segment in the previous batch is the same as the first segment in this batch,
            # then we merge them together into a single segment. In effect, this means removing it from
            # the segment indices of this batch, and extending the segment span by removing the segment
            # end index from the previous batch.
            segment_indices = segment_indices[1:]
            self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1]

        self.adapter_segment_indices.extend(segment_indices)
        self.adapter_segment_tensors.append(adapter_segments)

    def build(self) -> Tuple[torch.Tensor, List[int]]:
        return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices
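And a sketch of how `SegmentConcatBuilder` stitches two batches together (illustration only, not part of the commit): the first batch covers adapter indices `[0, 0, 1]`, the second `[1, 2]`, and the shared adapter-1 run at the boundary is merged into a single segment.

import torch

builder = SegmentConcatBuilder()

# batch 1: adapter indices [0, 0, 1] -> segments [0, 2, 3], run indices [0, 1]
builder.concat(torch.tensor([0, 2, 3]), [0, 1])
# batch 2: adapter indices [1, 2]    -> segments [0, 1, 2], run indices [1, 2]
builder.concat(torch.tensor([0, 1, 2]), [1, 2])

segments, indices = builder.build()
print(segments)  # tensor([0, 2, 4, 5]) -- same as find_segments([0, 0, 1, 1, 2])
print(indices)   # [0, 1, 2]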
@@ -0,0 +1,242 @@
import os
import warnings
from functools import lru_cache
from typing import List, Tuple

import torch
import torch.nn.functional as F

try:
    # TODO: add build steps for Punica kernels
    # import punica_kernels as _kernels
    import punica.ops as _kernels

    HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
except ImportError:
    warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.")
    _kernels = None
    HAS_SGMV = False


MIN_SGMV_RANK = 8
MIN_RANK_CUSTOM = 16
MAX_RANK_CUSTOM = 128
SGMV_BLOCK_SIZE = 16
BGMV_MAX_RANK = 64


def has_sgmv() -> bool:
    return HAS_SGMV


def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
    """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size."""
    if not has_sgmv():
        return t

    # tensor parallelism will result in effective rank being divided by world_size,
    # so we need to scale the min rank to offset that effect
    min_rank = MIN_SGMV_RANK * world_size

    # if we're at or below the min rank, pad up to the min rank
    # otherwise, pad to the nearest multiple of the block size
    current_rank = t.size(dim)
    target_rank = (
        min_rank
        if current_rank <= min_rank
        else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
    )
    if current_rank == target_rank:
        return t

    pad_size = target_rank - current_rank

    # see complicated pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
    pad = [0, 0] * t.dim()
    pad[(t.dim() - dim - 1) * 2 + 1] = pad_size
    pad = tuple(pad)

    return F.pad(t, pad, mode="constant", value=0.0)
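For intuition, a sketch of what `pad_rank` does when the SGMV kernels are available (if they are not, the tensor is returned unchanged); the shapes are illustrative assumptions, and the function above is assumed to be in scope.

import torch

# LoRA A matrix with rank 10 on the last dimension, single GPU (world_size=1)
lora_a = torch.randn(4096, 10)

# min_rank = 8 and 10 > 8, so the rank is rounded up to the next multiple of
# SGMV_BLOCK_SIZE (16): the result is zero-padded to shape [4096, 16]
padded = pad_rank(lora_a, dim=1, world_size=1)

# a rank-4 adapter would instead be padded up to the minimum rank of 8
small = pad_rank(torch.randn(4096, 4), dim=1, world_size=1)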
def use_cutlass_shrink(lora_rank: int) -> bool:
    return lora_rank < MIN_RANK_CUSTOM


def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
    if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
        return t.transpose(0, 1)
    return t


# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py
def add_lora_sgmv_cutlass(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_ptr: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.Tensor,
    s_end: torch.Tensor,
    layer_idx: int,
    lora_rank: int,
):
    """
    Semantics:
        y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])

    Args:
        y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
        x: Shape: `[B, H1]`. Input vectors.
        wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
            Weight matrix shape: `[num_layers, R, H1]`.
        wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
            Weight matrix shape: `[num_layers, R, H2]`.
        s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices.
        s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices.
        layer_idx: Layer index of the weight matrices.
    """
    if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
        # Custom SGMV shrink only supports rank 16, 32, 64, 128
        _add_lora_sgmv_cutlass_legacy(
            y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank
        )
        return

    tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
    tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
    tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)
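The docstring's semantics can be written as a plain-PyTorch reference loop. This sketch is not part of the commit and is not the kernel's API: it takes per-segment weight tensors directly instead of the device pointers and layer index the kernel expects.

import torch


def add_lora_sgmv_reference(
    y: torch.Tensor,        # [B, H2], updated in place
    x: torch.Tensor,        # [B, H1]
    wa: list,               # per-segment LoRA A tensors, each [R, H1]
    wb: list,               # per-segment LoRA B tensors, each [R, H2]
    s_start: torch.Tensor,  # [S] segment start offsets
    s_end: torch.Tensor,    # [S] segment end offsets
):
    for i in range(len(wa)):
        start, end = int(s_start[i]), int(s_end[i])
        if end <= start:
            continue
        # shrink to rank R, then expand back to H2, matching the documented semantics
        y[start:end] += x[start:end] @ wa[i].T @ wb[i]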
def _add_lora_sgmv_cutlass_legacy(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_ptr: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
    lora_rank: int,
):
    tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
    tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)


@lru_cache(maxsize=1)
def get_tmp_tensor(device: torch.device) -> torch.Tensor:
    return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)


@lru_cache(maxsize=32)
def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
    tmp_size = _kernels.sgmv_cutlass_tmp_size(size)
    return torch.empty((tmp_size,), dtype=torch.uint8, device=device)


def get_tmp_expand_size(size: int) -> int:
    return _kernels.sgmv_cutlass_tmp_size(size)


def get_tmp_tensors(
    nsegments: int, lora_rank: int, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
    if use_cutlass_shrink(lora_rank):
        tmp = get_tmp_tensor_for_size(nsegments, device)
        return tmp, tmp
    else:
        tmp_shrink = get_tmp_tensor(device)
        tmp_expand = get_tmp_tensor_for_size(nsegments, device)
        return tmp_shrink, tmp_expand


def lora_a_sgmv_cutlass(
    x: torch.Tensor,
    tmp: torch.Tensor,
    wa_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
    lora_rank: int,
) -> torch.Tensor:
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
        _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    else:
        _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    return v


def lora_b_sgmv_cutlass(
    y: torch.Tensor,
    v: torch.Tensor,
    tmp: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
):
    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)


"""
Semantics:
    y[i] += (
        x[i].unsqueeze(0)
        @ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
        @ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
        * scale
    ).squeeze(0)

Args:
    y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
    v: Shape: `[B, R]`. Temporary vector.
    x: Shape: `[B, H1]`. Input vectors.
    wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices.
    wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices.
    indicies: Shape: `[B]`. Indices of the LoRA weights.
    layer_idx: Layer index of LoRA weights.
    scale: Scaling factor.
"""


def add_lora_a_bgmv(
    v: torch.Tensor,
    x: torch.Tensor,
    wa_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
):
    _kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)


def add_lora_b_bgmv(
    y: torch.Tensor,
    v: torch.Tensor,
    wb_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
):
    _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)


def segmented_matmul(
    y: torch.Tensor,
    x: torch.Tensor,
    w: List[torch.Tensor],
    b: List[torch.Tensor],
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
):
    for i in range(len(w)):
        if s_end[i] - s_start[i] <= 0:
            continue

        xi = x[s_start[i] : s_end[i]]
        wi = w[i]
        bi = b[i]
        y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi)
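A tiny usage sketch of `segmented_matmul` (illustration only, not part of the commit, assuming the function above is in scope): two segments of a batch are projected with two different weight matrices, which is the loop fallback used when the SGMV kernels are unavailable.

import torch

x = torch.randn(5, 8)                       # 5 rows with hidden size 8
y = torch.zeros(5, 4)                       # output buffer with hidden size 4
w = [torch.randn(4, 8), torch.randn(4, 8)]  # one weight per segment (F.linear convention)
b = [torch.zeros(4), torch.zeros(4)]

s_start = torch.tensor([0, 2], dtype=torch.int32)
s_end = torch.tensor([2, 5], dtype=torch.int32)

segmented_matmul(y, x, w, b, s_start, s_end)  # rows 0-1 use w[0], rows 2-4 use w[1]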