feat: prefer lorax implementation and port loading logic
parent c661631225
commit 8b50f4b779
@@ -157,7 +157,7 @@ async fn prefill(
             top_n_tokens: top_n_tokens.unwrap_or(0),
             blocks: vec![],
             slots: vec![],
-            adapter_id: None,
+            adapter_index: None,
         })
         .collect();
 
@@ -107,8 +107,8 @@ message Request {
    bool prefill_logprobs = 6;
    /// Return most likely n tokens
    uint32 top_n_tokens = 7;
-    /// LORA adapter id
-    optional string adapter_id = 8;
+    /// LORA adapter index
+    optional uint32 adapter_index = 8;
}

message Batch {
@@ -154,7 +154,7 @@ impl Client {
                }),
                prefill_logprobs: true,
                top_n_tokens: 20,
-                adapter_id: None,
+                adapter_index: None,
            });
            n_tokens += max_input_length;
 
@@ -290,7 +290,7 @@ impl State {
                    entry.request.stopping_parameters.clone(),
                )),
                top_n_tokens: entry.request.top_n_tokens,
-                adapter_id: entry.request.adapter_id.clone(),
+                adapter_index: entry.request.adapter_index,
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
@@ -430,7 +430,7 @@ mod tests {
                stop_sequences: vec![],
            },
            top_n_tokens: 0,
-            adapter_id: None,
+            adapter_index: None,
        },
        response_tx,
        span: info_span!("entry"),
@@ -302,7 +302,7 @@ pub(crate) struct GenerateParameters {
    /// Lora adapter id
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
-    pub adapter_id: Option<String>,
+    pub adapter_index: Option<u32>,
}

fn default_max_new_tokens() -> Option<u32> {
@@ -329,7 +329,7 @@ fn default_parameters() -> GenerateParameters {
        seed: None,
        top_n_tokens: None,
        grammar: None,
-        adapter_id: None,
+        adapter_index: None,
    }
}
 
@@ -202,7 +202,7 @@ impl Validation {
            decoder_input_details,
            top_n_tokens,
            grammar,
-            adapter_id,
+            adapter_index,
            ..
        } = request.parameters;
 
@@ -384,7 +384,7 @@ impl Validation {
            parameters,
            stopping_parameters,
            top_n_tokens,
-            adapter_id,
+            adapter_index,
        })
    }
 
@@ -680,7 +680,7 @@ pub(crate) struct ValidGenerateRequest {
    pub parameters: ValidParameters,
    pub stopping_parameters: ValidStoppingParameters,
    pub top_n_tokens: u32,
-    pub adapter_id: Option<String>,
+    pub adapter_index: Option<u32>,
}

#[derive(Error, Debug)]
@@ -0,0 +1,31 @@
+import json
+from pathlib import Path
+from typing import Dict, Optional
+
+from text_generation_server.adapters.config import AdapterConfig
+from text_generation_server.adapters.lora import LoraConfig
+from text_generation_server.adapters.weights import (
+    AdapterBatchData,
+    AdapterBatchMetadata,
+)
+
+
+def load_adapter_config(
+    config_path: Optional[Path],
+    adapter_config_path: Optional[Path],
+    api_token: str,
+) -> AdapterConfig:
+    if adapter_config_path is not None and adapter_config_path.exists():
+        return LoraConfig.load(str(adapter_config_path.parent), api_token)
+
+    raise ValueError(
+        f"No valid adapter config file found: "
+        f"tried {adapter_config_path} and {config_path}"
+    )
+
+
+__all__ = [
+    "AdapterBatchData",
+    "AdapterBatchMetadata",
+    "load_adapter_config",
+]
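The loader above resolves an adapter's PEFT configuration from a local directory. A minimal, hypothetical usage sketch follows; the directory path and empty token are placeholders and not part of this commit:

    from pathlib import Path

    adapter_dir = Path("/data/adapters/my-lora")  # hypothetical local adapter directory
    config = load_adapter_config(
        config_path=adapter_dir / "config.json",
        adapter_config_path=adapter_dir / "adapter_config.json",
        api_token="",  # only needed when the config has to be fetched from the Hub
    )
    # `config` is a LoraConfig built from the adapter's PEFT metadata.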
@@ -0,0 +1,37 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+
+import torch
+
+from text_generation_server.adapters.weights import AdapterWeights
+
+if TYPE_CHECKING:
+    from text_generation_server.models.model import Model
+
+
+ModuleMap = Dict[str, Dict[str, Tuple[torch.Tensor, str]]]
+
+
+@dataclass
+class AdapterConfig(ABC):
+    base_model_name_or_path: str
+
+    @abstractmethod
+    def map_weights_for_model(
+        self,
+        adapter_weights: Dict,
+        weight_names: Tuple[str],
+    ) -> Tuple[ModuleMap, Set[str]]:
+        pass
+
+    @abstractmethod
+    def load_batched_adapter_weights(
+        self,
+        model: "Model",
+        module_map: Dict[str, Dict],
+        layer_type: str,
+        unused_weight_names: Set[str],
+        dynamic: bool,
+    ) -> Optional[AdapterWeights]:
+        pass
@@ -0,0 +1,430 @@
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
+
+import torch
+from peft import LoraConfig as _LoraConfig
+from torch.distributed import ProcessGroup
+
+from text_generation_server.adapters.config import AdapterConfig, ModuleMap
+
+LORA = "lora"
+from text_generation_server.adapters.weights import (
+    AdapterBatchMetadata,
+    AdapterWeights,
+    BatchAdapterWeights,
+)
+from text_generation_server.utils.sgmv import (
+    BGMV_MAX_RANK,
+    MAX_RANK_CUSTOM,
+    get_tmp_tensors,
+    orient_for_rank,
+    pad_rank,
+    use_cutlass_shrink,
+)
+
+if TYPE_CHECKING:
+    from text_generation_server.models.model import Model
+
+EMPTY_TENSOR = torch.tensor([])
+
+
+@dataclass
+class LoraConfig(AdapterConfig):
+    r: int
+    target_modules: Optional[Union[List[str], str]]
+    fan_in_fan_out: bool
+    lora_alpha: int
+    use_rslora: bool
+
+    def map_weights_for_model(
+        self,
+        adapter_weights: Dict,
+        weight_names: Tuple[str],
+    ) -> Tuple[ModuleMap, Set[str]]:
+        adapter_weight_names = set()
+        module_map = {}
+        for weight_name in weight_names:
+            lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
+            lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
+            if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
+                continue
+
+            module_map[weight_name] = {
+                "lora_A": (adapter_weights[lora_a_name], lora_a_name),
+                "lora_B": (adapter_weights[lora_b_name], lora_b_name),
+            }
+            adapter_weight_names.add(lora_a_name)
+            adapter_weight_names.add(lora_b_name)
+        return module_map, adapter_weight_names
+
+    def load_batched_adapter_weights(
+        self,
+        model: "Model",
+        module_map: Dict[str, Dict],
+        layer_type: str,
+        unused_weight_names: Set[str],
+        dynamic: bool,
+    ) -> Optional[AdapterWeights]:
+        return LoraWeights.load(
+            self,
+            model,
+            module_map,
+            layer_type,
+            unused_weight_names,
+        )
+
+    @classmethod
+    def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
+        hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
+        return cls(
+            base_model_name_or_path=hf_config.base_model_name_or_path,
+            r=hf_config.r,
+            target_modules=hf_config.target_modules,
+            fan_in_fan_out=hf_config.fan_in_fan_out,
+            lora_alpha=hf_config.lora_alpha,
+            use_rslora=(
+                hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False
+            ),
+        )
+
+
+class LoraWeights(AdapterWeights):
+    """LoRA weights for a single adapter merged across all layers."""
+
+    def __init__(
+        self,
+        weights_a: List[torch.Tensor],
+        weights_b: List[torch.Tensor],
+        adapter_config: LoraConfig,
+    ):
+        self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
+        self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
+
+        self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
+        self._is_transposed = False
+
+        # [num_layers, hidden_size, r]
+        weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
+        self._weights_a = torch.stack(weights_a)
+
+        # [num_layers, r, hidden_size]
+        self._weights_b = torch.stack(weights_b)
+
+        self.adapter_config = adapter_config
+
+    @property
+    def weights_a(self) -> torch.Tensor:
+        if self._is_transposed:
+            self._transpose_weights()
+        return self._weights_a
+
+    @property
+    def weights_b(self) -> torch.Tensor:
+        if self._is_transposed:
+            self._transpose_weights()
+        return self._weights_b
+
+    @property
+    def weights_a_t(self) -> torch.Tensor:
+        if not self._is_transposed:
+            self._transpose_weights()
+        return self._weights_a
+
+    @property
+    def weights_b_t(self) -> torch.Tensor:
+        if not self._is_transposed:
+            self._transpose_weights()
+        return self._weights_b
+
+    def _transpose_weights(self):
+        if self._use_cutlass_shrink:
+            # If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation
+            self._weights_a = self._weights_a.transpose(1, 2).contiguous()
+        self._weights_b = self._weights_b.transpose(1, 2).contiguous()
+        self._is_transposed = not self._is_transposed
+
+    @classmethod
+    def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
+        return [BatchLoraWeights]
+
+    @classmethod
+    def load(
+        cls,
+        config: LoraConfig,
+        model: "Model",
+        module_map: Dict[str, Dict],
+        layer_type: str,
+        unused_weight_names: Set[str],
+    ) -> Optional[AdapterWeights]:
+        nlayers = model.get_num_layers_for_type(layer_type)
+        lora_a_list = [None] * nlayers
+        lora_b_list = [None] * nlayers
+
+        for layer_id in range(nlayers):
+            key = (layer_id, layer_type)
+            weight_name, layer = model.target_to_layer[key]
+            base_weight = layer.base_layer.linear.weight
+            base_device = base_weight.device
+
+            if weight_name not in module_map:
+                # There is no LoRA weight for this layer type in the adapter
+                return None
+
+            lora_a, lora_a_name = module_map[weight_name]["lora_A"]
+            lora_a = lora_a.to(base_device, model.dtype)
+
+            lora_b, lora_b_name = module_map[weight_name]["lora_B"]
+            lora_b = lora_b.to(base_device, model.dtype)
+
+            scale = get_scaling_factor(
+                config.lora_alpha,
+                config.r,
+                uses_rslora=config.use_rslora,
+            )
+
+            unused_weight_names.discard(lora_a_name)
+            unused_weight_names.discard(lora_b_name)
+
+            # Merge scaling factor into lora_b due to associativity of matrix multiplication:
+            # (A * B) * C = A * (B * C)
+            lora_a_list[layer_id] = lora_a.transpose(0, 1)
+            lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
+
+        # pad lora ranks to be compatible with sgmv
+        lora_a_list = [
+            pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list
+        ]
+        lora_b_list = [
+            pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list
+        ]
+
+        if lora_a_list:
+            # update rank if it was padded
+            padded_rank = lora_a_list[0].size(1)
+            config.r = padded_rank
+
+        return LoraWeights(
+            *model.shard_lora_weights(lora_a_list, lora_b_list, layer_type),
+            config,
+        )
+
+
+@dataclass
+class RankSegments:
+    rank: int
+
+    lora_a_ptr: torch.Tensor
+    lora_b_ptr: torch.Tensor
+
+    # prefill (sgmv)
+    tmp_shrink: torch.Tensor
+    tmp_expand: torch.Tensor
+    segment_starts: torch.Tensor
+    segment_ends: torch.Tensor
+
+    # decode (bgmv)
+    indices: torch.Tensor
+
+
+@dataclass
+class BatchLoraWeights(BatchAdapterWeights):
+    lora_a: Dict[int, torch.Tensor]
+    lora_b: Dict[int, torch.Tensor]
+    adapter_index_configs: Dict[int, LoraConfig]
+    rank_data: Dict[int, RankSegments]
+    use_sgmv: bool
+
+    def has_adapter(self, adapter_index: int) -> bool:
+        return adapter_index in self.adapter_index_configs
+
+    def can_vectorize(self, pg: ProcessGroup) -> bool:
+        return all(
+            rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
+            for rank_data in self.rank_data.values()
+        )
+
+    @classmethod
+    def key(cls) -> str:
+        return LORA
+
+    @classmethod
+    def load(
+        self,
+        adapter_weights: Dict[int, AdapterWeights],
+        meta: AdapterBatchMetadata,
+        prefill: bool,
+        prefill_head_indices: Optional[torch.Tensor],
+    ) -> Optional["BatchLoraWeights"]:
+        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
+        adapter_weights = {
+            k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
+        }
+        if not adapter_weights:
+            return None
+
+        first_weights = list(adapter_weights.values())[0]
+        device = first_weights.weights_a.device
+        segment_indices = meta.segment_indices
+
+        lora_a = {
+            idx: adapter_weights[idx].weights_a
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+        lora_b = {
+            idx: adapter_weights[idx].weights_b
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+
+        max_rank = max(
+            adapter_weights[idx].lora_a_r
+            for idx in segment_indices
+            if idx in adapter_weights
+        )
+
+        if prefill or max_rank > BGMV_MAX_RANK:
+            use_sgmv = True
+            lora_a_ptr = torch.tensor(
+                [
+                    (
+                        adapter_weights[idx].weights_a.data_ptr()
+                        if idx in adapter_weights
+                        else EMPTY_TENSOR.data_ptr()
+                    )
+                    for idx in segment_indices
+                ],
+                dtype=torch.int64,
+                device=device,
+            )
+            lora_b_ptr = torch.tensor(
+                [
+                    (
+                        adapter_weights[idx].weights_b.data_ptr()
+                        if idx in adapter_weights
+                        else EMPTY_TENSOR.data_ptr()
+                    )
+                    for idx in segment_indices
+                ],
+                dtype=torch.int64,
+                device=device,
+            )
+        else:
+            use_sgmv = False
+            lora_a_ptr = torch.tensor(
+                [
+                    (
+                        adapter_weights[idx].weights_a_t.data_ptr()
+                        if idx in adapter_weights
+                        else EMPTY_TENSOR.data_ptr()
+                    )
+                    for idx in segment_indices
+                ],
+                dtype=torch.int64,
+                device=device,
+            )
+            lora_b_ptr = torch.tensor(
+                [
+                    (
+                        adapter_weights[idx].weights_b_t.data_ptr()
+                        if idx in adapter_weights
+                        else EMPTY_TENSOR.data_ptr()
+                    )
+                    for idx in segment_indices
+                ],
+                dtype=torch.int64,
+                device=device,
+            )
+
+        adapter_index_configs = {
+            idx: adapter_weights[idx].adapter_config
+            for idx in segment_indices
+            if idx in adapter_weights
+        }
+
+        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
+
+        rank_indices = defaultdict(list)
+        for segment_idx, adapter_idx in enumerate(segment_indices):
+            if adapter_idx not in adapter_weights:
+                continue
+            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
+
+        if prefill_head_indices is not None:
+            j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
+            for head_index in prefill_head_indices:
+                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
+                if head_index < meta.adapter_segments[j]:
+                    prefill_head_segment_ends[-1] += 1
+                else:
+                    prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
+                    prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
+                    j += 1
+
+        rank_data = {}
+        for rank, indices in rank_indices.items():
+            tmp_shrink = None
+            tmp_expand = None
+            segment_starts = None
+            segment_ends = None
+            batch_indices = None
+
+            if use_sgmv:
+                lora_a_ptr_indices = lora_a_ptr[indices]
+                tmp_shrink, tmp_expand = get_tmp_tensors(
+                    lora_a_ptr_indices.size(0), rank, device
+                )
+                segment_starts = meta.adapter_segments[indices]
+                segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
+                if prefill_head_indices is not None:
+                    for i, segment_index in enumerate(indices):
+                        segment_starts[i] = prefill_head_segment_starts[segment_index]
+                        segment_ends[i] = prefill_head_segment_ends[segment_index]
+            else:
+                rank_indices = set(indices)
+                batch_indices = [
+                    adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
+                ]
+                batch_indices = [
+                    idx if idx in rank_indices else -1 for idx in batch_indices
+                ]
+                batch_indices = torch.tensor(
+                    batch_indices, dtype=torch.int64, device=device
+                )
+
+            rank_data[rank] = RankSegments(
+                rank=rank,
+                tmp_shrink=tmp_shrink,
+                tmp_expand=tmp_expand,
+                lora_a_ptr=lora_a_ptr[indices],
+                lora_b_ptr=lora_b_ptr[indices],
+                segment_starts=segment_starts,
+                segment_ends=segment_ends,
+                indices=batch_indices,
+            )
+
+        return BatchLoraWeights(
+            lora_a=lora_a,
+            lora_b=lora_b,
+            adapter_index_configs=adapter_index_configs,
+            rank_data=rank_data,
+            use_sgmv=use_sgmv,
+        )
+
+
+def get_scaling_factor(
+    lora_alpha: int,
+    r: int,
+    uses_rslora: bool = False,
+) -> float:
+    """Computes the scaling factor for the lora weights."""
+    if uses_rslora:
+        return lora_alpha / (r**0.5)
+    return lora_alpha / r
+
+
+def _convert_lora(v: AdapterWeights) -> AdapterWeights:
+    if hasattr(v, "lora_weights"):
+        return v.lora_weights
+    return v
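The loader above folds the LoRA scaling factor into lora_b once at load time, so the runtime kernels only compute (x @ lora_a) @ lora_b. A standalone sketch of that arithmetic, with illustrative values:

    def scaling(lora_alpha: int, r: int, uses_rslora: bool = False) -> float:
        # mirrors get_scaling_factor above: alpha / sqrt(r) for rsLoRA, alpha / r otherwise
        return lora_alpha / (r**0.5) if uses_rslora else lora_alpha / r

    print(scaling(16, 8))         # 2.0 -> lora_b is pre-multiplied by 2.0
    print(scaling(16, 16, True))  # 4.0 with the rsLoRA variant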
@@ -0,0 +1,159 @@
+#############
+from abc import ABC, abstractclassmethod
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Set, Type
+
+import torch
+
+
+LORA = "lora"
+LM_HEAD = "lm_head"
+
+
+@dataclass
+class AdapterBatchMetadata:
+    # [batch_size]
+    adapter_indices: torch.Tensor
+
+    # [num_adapters]
+    adapter_set: Set[int]
+
+    # [num_segments + 1]
+    adapter_segments: torch.Tensor
+
+    # [num_segments]
+    # maps from segment index to adapter index, i.e.:
+    # segment_indices[s] == adapter_indices[i]
+    segment_indices: List[int]
+
+
+class AdapterWeights(ABC):
+    @abstractclassmethod
+    def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
+        pass
+
+    @property
+    def speculative_tokens(self) -> int:
+        return 0
+
+
+class BatchAdapterWeights(ABC):
+    @abstractclassmethod
+    def has_adapter(self, adapter_index: int) -> bool:
+        pass
+
+    @abstractclassmethod
+    def key(cls) -> str:
+        pass
+
+    @abstractclassmethod
+    def load(
+        cls,
+        adapter_weights: Dict[int, AdapterWeights],
+        meta: "AdapterBatchMetadata",
+        prefill: bool,
+        prefill_head_indices: torch.Tensor,
+    ) -> Optional["BatchAdapterWeights"]:
+        pass
+
+
+class LayerAdapterWeights:
+    """Adapter weights that apply to a particular layer."""
+
+    def __init__(self):
+        self.adapter_weights: Dict[int, AdapterWeights] = {}
+
+    def add_adapter(self, adapter_idx: int, weights: AdapterWeights):
+        self.adapter_weights[adapter_idx] = weights
+
+    def remove_adapter(self, adapter_idx: int):
+        if adapter_idx not in self.adapter_weights:
+            return
+        del self.adapter_weights[adapter_idx]
+
+    @property
+    def max_speculative_tokens(self) -> int:
+        return max(
+            adapter_weights.speculative_tokens
+            for adapter_weights in self.adapter_weights.values()
+        )
+
+    def is_empty(self) -> bool:
+        return len(self.adapter_weights) == 0
+
+    def get_data(
+        self,
+        meta: AdapterBatchMetadata,
+        prefill: bool,
+        prefill_head_indices: Optional[torch.Tensor],
+    ) -> Dict[str, BatchAdapterWeights]:
+        # bucket adapters by batch class
+        adapter_batch_types: Dict[
+            Type[BatchAdapterWeights], Dict[int, AdapterWeights]
+        ] = defaultdict(dict)
+        for adapter_index, adapter_weights in self.adapter_weights.items():
+            for batch_type in adapter_weights.get_batch_types():
+                adapter_batch_types[batch_type][adapter_index] = adapter_weights
+
+        batch_data = {}
+        for batch_type, adapter_weights in adapter_batch_types.items():
+            batched_weights = batch_type.load(
+                adapter_weights, meta, prefill, prefill_head_indices
+            )
+            if batched_weights is not None:
+                batch_data[batch_type.key()] = batched_weights
+        return batch_data
+
+
+@dataclass
+class AdapterBatchData:
+    meta: AdapterBatchMetadata
+
+    # layer type -> adapter type -> batch weight data
+    data: Dict[str, Dict[str, BatchAdapterWeights]]
+
+    prefill: bool
+
+    @staticmethod
+    def from_meta(
+        meta: AdapterBatchMetadata,
+        weights: Dict[str, LayerAdapterWeights],
+        prefill: bool,
+        prefill_head_indices: Optional[torch.Tensor],
+    ) -> "AdapterBatchData":
+        data = {}
+        for k, v in weights.items():
+            if v.is_empty():
+                continue
+            data[k] = v.get_data(
+                meta, prefill, prefill_head_indices if k == LM_HEAD else None
+            )
+        return AdapterBatchData(meta=meta, data=data, prefill=prefill)
+
+    def ranks(self) -> Set[int]:
+        # TODO(travis): refactor to be less coupled to lora implementation
+        ranks = set()
+        for layer_data in self.data.values():
+            lora_data = layer_data.get(LORA)
+            if lora_data is None:
+                continue
+
+            for rank_data in lora_data.rank_data.values():
+                ranks.add(rank_data.rank)
+
+        return ranks
+
+    def layer_names(self) -> Set[str]:
+        return set(self.data.keys())
+
+    def adapter_keys(self) -> Set[str]:
+        adapter_keys = set()
+        for layer_data in self.data.values():
+            adapter_keys.update(layer_data.keys())
+        return adapter_keys
+
+    @property
+    def max_rank(self) -> int:
+        ranks = self.ranks()
+        return max(ranks) if len(ranks) > 0 else 0
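AdapterBatchMetadata assumes the per-token adapter indices are grouped into contiguous segments, with segment_indices mapping each segment back to its adapter. A small illustrative sketch of that relationship (this is not the find_segments helper imported elsewhere in this commit, just a plain-Python illustration):

    from typing import List, Tuple

    def segments_from_indices(adapter_indices: List[int]) -> Tuple[List[int], List[int]]:
        # Collapse runs of identical adapter indices into [num_segments + 1] boundaries
        # and a [num_segments] list mapping each segment back to its adapter index.
        boundaries, segment_indices = [0], []
        for i, idx in enumerate(adapter_indices):
            if not segment_indices or idx != segment_indices[-1]:
                if segment_indices:
                    boundaries.append(i)
                segment_indices.append(idx)
        boundaries.append(len(adapter_indices))
        return boundaries, segment_indices

    # tokens for adapters 0,0,1,1,1,0 -> boundaries [0, 2, 5, 6], segments [0, 1, 0]
    print(segments_from_indices([0, 0, 1, 1, 1, 0]))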
@@ -12,3 +12,9 @@ from text_generation_server.layers.speculative import SpeculativeHead
 # Just to add the `load` methods.
 from text_generation_server.layers.layernorm import load_layer_norm
 from text_generation_server.layers.conv import load_conv2d
+
+from text_generation_server.layers.lora import (
+    LoraLinear,
+    TensorParallelMultiAdapterLinear,
+    TensorParallelAdapterRowLinear,
+)
@@ -0,0 +1,244 @@
+import math
+import os
+from typing import TYPE_CHECKING, Optional, Tuple
+
+import torch
+import torch.distributed
+from accelerate import init_empty_weights
+from torch import nn
+from torch.nn import functional as F
+
+from text_generation_server.utils.sgmv import (
+    add_lora_a_bgmv,
+    add_lora_b_bgmv,
+    has_sgmv,
+    lora_a_sgmv_cutlass,
+    lora_b_sgmv_cutlass,
+    orient_for_rank,
+)
+
+LORA = "lora"
+MEDUSA = "medusa"
+
+if TYPE_CHECKING:
+    from text_generation_server.adapters import AdapterBatchData
+    from text_generation_server.adapters.lora import BatchLoraWeights
+
+
+class LoraLinear(nn.Module):
+    def __init__(self, base_layer, layer_id, process_group):
+        super().__init__()
+        self.base_layer = base_layer
+        self.layer_id = layer_id
+        self.process_group = process_group
+
+    def forward_layer_type(
+        self,
+        result: torch.Tensor,
+        input: torch.Tensor,
+        adapter_data: "AdapterBatchData",
+        layer_type: str,
+        start_idx: int,
+        end_idx: int,
+    ) -> torch.Tensor:
+        data = adapter_data.data.get(layer_type)
+        data: Optional["BatchLoraWeights"] = (
+            data.get(LORA) if data is not None else None
+        )
+
+        if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
+            if end_idx - start_idx != result.shape[1]:
+                proj = torch.zeros_like(result[:, start_idx:end_idx])
+            else:
+                proj = result
+
+            for r, rank_segments in data.rank_data.items():
+                lora_a_ptr = rank_segments.lora_a_ptr
+                lora_b_ptr = rank_segments.lora_b_ptr
+
+                if data.use_sgmv:
+                    # Use SGMV for prefill
+                    if lora_a_ptr is not None and lora_b_ptr is not None:
+                        v = lora_a_sgmv_cutlass(
+                            input,
+                            rank_segments.tmp_shrink,
+                            lora_a_ptr,
+                            rank_segments.segment_starts,
+                            rank_segments.segment_ends,
+                            self.layer_id,
+                            r,
+                        )
+
+                        if self.process_group.size() > 1:
+                            v = self.collect_lora_a(v)
+
+                        lora_b_sgmv_cutlass(
+                            proj,
+                            v,
+                            rank_segments.tmp_expand,
+                            lora_b_ptr,
+                            rank_segments.segment_starts,
+                            rank_segments.segment_ends,
+                            self.layer_id,
+                        )
+                else:
+                    # Use BGMV for decode
+                    if lora_a_ptr is not None and lora_b_ptr is not None:
+                        v = torch.zeros(
+                            (input.size(0), r), dtype=input.dtype, device=input.device
+                        )
+                        add_lora_a_bgmv(
+                            v,
+                            input,
+                            lora_a_ptr,
+                            rank_segments.indices,
+                            self.layer_id,
+                        )
+
+                        if self.process_group.size() > 1:
+                            v = self.collect_lora_a(v)
+
+                        add_lora_b_bgmv(
+                            proj,
+                            v,
+                            lora_b_ptr,
+                            rank_segments.indices,
+                            self.layer_id,
+                        )
+
+            if end_idx - start_idx != result.shape[1]:
+                result[:, start_idx:end_idx] += proj
+        else:
+            for adapter_index in adapter_data.meta.adapter_set:
+                if data is not None and data.has_adapter(adapter_index):
+                    adapter_mask = (
+                        (adapter_data.meta.adapter_indices == adapter_index)
+                        .to(input.dtype)
+                        .view(-1, 1)
+                    )
+                    layer_result = self.forward_lora(
+                        input, data, adapter_index, adapter_mask
+                    )
+                    result[:, start_idx:end_idx] += layer_result
+
+        return result
+
+    def forward_lora(
+        self,
+        input: torch.Tensor,
+        data: "BatchLoraWeights",
+        adapter_index: int,
+        adapter_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
+        lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
+
+        lora_a = orient_for_rank(lora_a, lora_b.size(0))
+
+        a_out = input @ lora_a
+        if self.process_group.size() > 1:
+            a_out = self.collect_lora_a(a_out)
+
+        result = (a_out @ lora_b) * adapter_mask
+        return result
+
+    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError("Implemented in subclasses")
+
+
+class TensorParallelMultiAdapterLinear(LoraLinear):
+    def __init__(self, base_layer, layer_id, layer_names, sizes, process_group):
+        super().__init__(base_layer, layer_id, process_group)
+        self.layer_names = layer_names
+        self.sizes = sizes
+
+    @classmethod
+    def load(cls, base_layer, layer_id, layer_names, sizes, process_group):
+        return TensorParallelMultiAdapterLinear(
+            base_layer, layer_id, layer_names, sizes, process_group
+        )
+
+    def forward(
+        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
+    ) -> torch.Tensor:
+        result = self.base_layer(input)
+
+        # handle models like Bloom that have inputs of shape
+        # (batch_size, sequence_length, hidden_size)
+        # we need to reshape them to (batch_size * sequence_length, hidden_size)
+        # for the LoRA computation, then reshape back
+        prev_shape = result.shape
+        is_3d = len(input.shape) >= 3
+        if is_3d:
+            input = input.reshape(-1, input.shape[-1])
+            result = result.reshape(-1, result.shape[-1])
+
+        offset = 0
+        for i, layer_name in enumerate(self.layer_names):
+            start_idx = offset // self.process_group.size()
+
+            if self.sizes is not None:
+                offset += self.sizes[i]
+                end_idx = offset // self.process_group.size()
+            else:
+                end_idx = result.shape[1]
+
+            result = self.forward_layer_type(
+                result, input, adapter_data, layer_name, start_idx, end_idx
+            )
+
+        if is_3d:
+            result = result.reshape(prev_shape)
+
+        return result
+
+    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
+        # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
+        # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
+        #
+        # TODO(travis): this is not very efficient as we do an all-gather for every adapter,
+        # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
+        # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
+        # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
+        gathered_tensors = [
+            torch.empty_like(a_out) for _ in range(self.process_group.size())
+        ]
+        torch.distributed.all_gather(gathered_tensors, a_out)
+        return torch.cat(gathered_tensors, dim=1)
+
+
+class TensorParallelAdapterRowLinear(LoraLinear):
+    def __init__(self, base_layer, layer_id, layer_name, process_group):
+        super().__init__(base_layer, layer_id, process_group)
+        self.layer_name = layer_name
+
+    @classmethod
+    def load(cls, base_layer, layer_id, layer_name, process_group):
+        return cls(base_layer, layer_id, layer_name, process_group)
+
+    def forward(
+        self, input: torch.Tensor, adapter_data: "AdapterBatchData"
+    ) -> torch.Tensor:
+        result = self.base_layer(input)
+
+        # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
+        stride = result.shape[-1] // self.process_group.size()
+        start_idx = self.process_group.rank() * stride
+        end_idx = (self.process_group.rank() + 1) * stride
+
+        self.forward_layer_type(
+            result, input, adapter_data, self.layer_name, start_idx, end_idx
+        )
+
+        return result
+
+    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
+        # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
+        # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
+        #
+        # TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
+        # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
+        # rank, compute `a_out` on each, and then slice them into the buffer as shown here:
+        # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
+        torch.distributed.all_reduce(a_out, group=self.process_group)
+        return a_out
@@ -38,6 +38,8 @@ from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     SpeculativeHead,
+    TensorParallelMultiAdapterLinear,
+    TensorParallelAdapterRowLinear,
 )
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
@@ -50,6 +52,16 @@ if SYSTEM == "rocm":
     except Exception as e:
         raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
 
+# Constants
+Q_PROJ = "q_proj"
+K_PROJ = "k_proj"
+V_PROJ = "v_proj"
+O_PROJ = "o_proj"
+
+GATE_PROJ = "gate_proj"
+UP_PROJ = "up_proj"
+DOWN_PROJ = "down_proj"
+
 
 def load_attention(config, prefix, weights):
     # Only defined in granite.
@@ -57,7 +69,7 @@ def load_attention(config, prefix, weights):
 
     # if specific model type, load the correct attention
     if config.model_type == "phi3":
-        return TensorParallelColumnLinear.load_qkv(
+        base_layer = TensorParallelColumnLinear.load_qkv(
             config,
             prefix=f"{prefix}.qkv_proj",
             weights=weights,
@@ -66,7 +78,7 @@ def load_attention(config, prefix, weights):
             num_key_value_heads=config.num_key_value_heads,
         )
     elif config.model_type == "baichuan":
-        return TensorParallelColumnLinear.load_qkv(
+        base_layer = TensorParallelColumnLinear.load_qkv(
             config,
             prefix=f"{prefix}.W_pack",
             weights=weights,
@@ -76,7 +88,7 @@ def load_attention(config, prefix, weights):
         )
 
     # otherwise, load the default attention based on the number of heads
-    return TensorParallelColumnLinear.load_multi(
+    base_layer = TensorParallelColumnLinear.load_multi(
         config,
         prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
         dim=0,
@@ -84,6 +96,19 @@ def load_attention(config, prefix, weights):
         bias=bias,
     )
 
+    head_size = config.hidden_size // config.num_attention_heads
+    return TensorParallelMultiAdapterLinear.load(
+        base_layer,
+        layer_id,
+        [Q_PROJ, K_PROJ, V_PROJ],
+        sizes=[
+            head_size * config.num_attention_heads,
+            head_size * config.num_key_value_heads,
+            head_size * config.num_key_value_heads,
+        ],
+        process_group=weights.process_group,
+    )
+
 
 class FlashLlamaAttention(torch.nn.Module):
     def __init__(
@@ -124,7 +149,7 @@ class FlashLlamaAttention(torch.nn.Module):
             config.num_key_value_heads // weights.process_group.size()
         )
 
-        self.query_key_value = load_attention(config, prefix, weights)
+        self.query_key_value = load_attention(config, prefix, weights, index)
         self.index = index
         self.adapter_weights = {}
         adapter_names = list(lora_weights.keys())
@@ -161,12 +186,20 @@ class FlashLlamaAttention(torch.nn.Module):
                 pre_multiplied_lora_matrix
            )
 
-        self.o_proj = TensorParallelRowLinear.load(
+        o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
         )
+
+        self.o_proj = TensorParallelAdapterRowLinear.load(
+            o_proj,
+            index,
+            O_PROJ,
+            process_group=weights.process_group,
+        )
+
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -185,8 +218,9 @@ class FlashLlamaAttention(torch.nn.Module):
         max_s,
         batch_lora_adapter_mask,
         lora_indices,
+        adapter_data,
     ):
-        qkv = self.query_key_value(hidden_states)
+        qkv = self.query_key_value(hidden_states, adapter_data)
         query, kv = qkv.split(
             [
                 self.head_size * self.num_heads,
@@ -197,32 +231,6 @@ class FlashLlamaAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
-        batch_size = query.size(0)
-
-        # hidden states without LoRA
-        hs_wl = hidden_states[lora_indices == -1]
-
-        adapted_query_states = [hs_wl]
-        adapted_value_states = [hs_wl]
-
-        for ind in range(self.n_loras):
-            mask = lora_indices == ind
-            hs_sub = hidden_states[mask]
-            mat_q = torch.matmul(hs_sub, self.pre_multiplied_lora_matrices[ind, 0])
-            mat_v = torch.matmul(hs_sub, self.pre_multiplied_lora_matrices[ind, 1])
-            adapted_query_states.append(mat_q)
-            adapted_value_states.append(mat_v)
-
-        query_adapted = torch.cat(adapted_query_states, dim=0).view(
-            batch_size, self.num_heads, self.head_size
-        )
-        value_adapted = torch.cat(adapted_value_states, dim=0).view(
-            batch_size, self.num_key_value_heads, self.head_size
-        )
-
-        query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask]
-        kv[batch_lora_adapter_mask, 1] += value_adapted[batch_lora_adapter_mask]
-
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
@@ -260,7 +268,7 @@ class FlashLlamaAttention(torch.nn.Module):
 
 
 class LlamaMLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, index):
         super().__init__()
         self.hidden_act = config.hidden_act
         self.act = (
@@ -278,26 +286,46 @@ class LlamaMLP(nn.Module):
         # Fuse gate and up proj
         bias = getattr(config, "mlp_bias", False)
         if config.model_type == "phi3":
-            self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
+            gate_up_proj = TensorParallelColumnLinear.load_gate_up(
                 config,
                 prefix=f"{prefix}.gate_up_proj",
                 weights=weights,
                 bias=bias,
             )
         else:
-            self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+            gate_up_proj = TensorParallelColumnLinear.load_multi(
                 config,
                 prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
                 weights=weights,
                 dim=0,
                 bias=bias,
            )
-        self.down_proj = TensorParallelRowLinear.load(
+
+        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+            gate_up_proj,
+            index,
+            [GATE_PROJ, UP_PROJ],
+            sizes=[
+                config.intermediate_size,
+                config.intermediate_size,
+            ],
+            process_group=weights.process_group,
+        )
+
+        down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=bias,
        )
+
+        self.down_proj = TensorParallelAdapterRowLinear.load(
+            down_proj,
+            index,
+            DOWN_PROJ,
+            process_group=weights.process_group,
+        )
+
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )
@@ -337,7 +365,9 @@ class FlashLlamaLayer(nn.Module):
             lora_weights=lora_weights,
             lora_configs=lora_configs,
         )
-        self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+        self.mlp = LlamaMLP(
+            prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
+        )
 
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@@ -362,6 +392,7 @@ class FlashLlamaLayer(nn.Module):
         max_s,
         batch_lora_adapter_mask,
         lora_indices,
+        adapter_data,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
 
@@ -378,6 +409,7 @@ class FlashLlamaLayer(nn.Module):
             max_s,
             batch_lora_adapter_mask,
             lora_indices,
+            adapter_data,
         )
 
         # faster post attention rms norm
@@ -440,6 +472,7 @@ class FlashLlamaModel(torch.nn.Module):
         prefill_cache_indices: Optional[torch.Tensor],
         batch_lora_adapter_mask: Optional[List[str]],
         lora_indices: Optional[torch.Tensor],
+        adapter_data,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
 
@@ -464,6 +497,7 @@ class FlashLlamaModel(torch.nn.Module):
                 max_s,
                 batch_lora_adapter_mask,
                 lora_indices,
+                adapter_data,
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -512,6 +546,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         batch_lora_adapter_mask: Optional[List[str]] = None,
         lora_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
@@ -527,6 +562,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             prefill_cache_indices=prefill_cache_indices,
             batch_lora_adapter_mask=batch_lora_adapter_mask,
             lora_indices=lora_indices,
+            adapter_data=adapter_data,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
@@ -13,6 +13,7 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Iterable, Optional, Tuple, List, Type, Dict
 
+from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.import_utils import SYSTEM
@@ -31,6 +32,7 @@ from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
 import text_generation_server.models.globals as tgi_globals
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
+from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
 
 from text_generation_server.utils.import_utils import (
     empty_cache,
@@ -114,6 +116,9 @@ class FlashCausalLMBatch(Batch):
     top_n_tokens: List[int]
     top_n_tokens_tensor: torch.Tensor
 
+    # Adapter metadata for each request
+    adapter_meta: AdapterBatchMetadata
+
     # Number of blocks in this batch
     num_blocks: int
     # Maximum number of blocks
@@ -174,6 +179,9 @@ class FlashCausalLMBatch(Batch):
         stopping_criterias = []
         top_n_tokens = []
 
+        adapter_indices_list = []
+        adapter_set = set()
+
         # Cumulative length
         cumulative_length = 0
         cumulative_max_length = 0
@@ -225,6 +233,9 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias.append(stopping_criteria)
             top_n_tokens.append(r.top_n_tokens)
 
+            adapter_indices_list.append(torch.full((input_length,), r.adapter_index))
+            adapter_set.add(r.adapter_index)
+
             # Paged attention
             # Remove one as the first token des not have a past
             speculative_length = get_speculate()
@@ -296,6 +307,10 @@ class FlashCausalLMBatch(Batch):
                 max_length, input_length + max_new_tokens + speculative_length
             )
 
+        adapter_indices = torch.cat(adapter_indices_list).to(
+            dtype=torch.int64, device=device
+        )
+
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters, dtype, device, tokenizer
         )
@@ -339,6 +354,11 @@ class FlashCausalLMBatch(Batch):
             input_lengths, dtype=torch.int32, device=device
         )
 
+        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
+        adapter_segments = torch.tensor(
+            adapter_segments, dtype=torch.int32, device=device
+        )
+
         if all_prefill_logprobs:
             prefill_head_indices = None
             prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
@@ -393,6 +413,12 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor=top_n_tokens_tensor,
             num_blocks=num_blocks,
             max_blocks=max_blocks,
+            adapter_meta=AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            ),
             speculative_ids=None,
         )
 
@@ -443,6 +469,7 @@ class FlashCausalLMBatch(Batch):
 
         stopping_criterias = []
         top_n_tokens = []
+        adapter_set = set()
 
         num_blocks = 0
         max_blocks = 0
@@ -471,6 +498,8 @@ class FlashCausalLMBatch(Batch):
 
             top_n_tokens.append(self.top_n_tokens[idx])
 
+            adapter_set.add(self.requests[idx].adapter_index)
+
             remaining_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )
@@ -498,6 +527,7 @@ class FlashCausalLMBatch(Batch):
         # Index into tensors
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
+        adapter_indices = self.adapter_meta.adapter_indices[indices]
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         block_tables_tensor = self.block_tables_tensor[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
@@ -513,6 +543,11 @@ class FlashCausalLMBatch(Batch):
         # Move to GPU now that we have the whole tensor
         slot_indices = slot_indices.to(device)
 
+        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
+        adapter_segments = torch.tensor(
+            adapter_segments, dtype=torch.int32, device=device
+        )
+
         return type(self)(
             batch_id=self.batch_id,
             requests=requests,
@@ -543,6 +578,12 @@ class FlashCausalLMBatch(Batch):
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
+            adapter_meta=AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            ),
         )
 
     @classmethod
@@ -596,6 +637,14 @@ class FlashCausalLMBatch(Batch):
         top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
             total_batch_size,
        )
+        total_indices_size = sum(
+            b.adapter_meta.adapter_indices.shape[0] for b in batches
+        )
+        adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
+            total_indices_size
+        )
+        adapter_set = set()
+        adapter_segment_builder = SegmentConcatBuilder()
 
         start_slots = []
         block_tables = []
@@ -613,6 +662,7 @@ class FlashCausalLMBatch(Batch):
         # Cumulative length
         cumulative_batch_size = 0
         cumulative_slots = 0
+        cumulative_adapter_indices_size = 0
 
         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
@@ -637,6 +687,18 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
             slots[slots_start_index:slots_end_index] = batch.slots
 
+            # Copy over adapter indices
+            adapter_start_index = cumulative_adapter_indices_size
+            adapter_end_index = (
+                cumulative_adapter_indices_size
+                + batch.adapter_meta.adapter_indices.shape[0]
+            )
+            adapter_indices[adapter_start_index:adapter_end_index] = (
+                batch.adapter_meta.adapter_indices
+            )
+            cumulative_adapter_indices_size = adapter_end_index
+            adapter_set.update(batch.adapter_meta.adapter_set)
+
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
             ] = batch.all_input_ids_tensor[:, :max_length]
|
||||||
|
@ -680,6 +742,8 @@ class FlashCausalLMBatch(Batch):
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_adapter_segments, _adapter_segment_indices = adapter_segment_builder.build()
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
batch_id=batches[0].batch_id,
|
batch_id=batches[0].batch_id,
|
||||||
requests=requests,
|
requests=requests,
|
||||||
|
@ -719,6 +783,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
class FlashCausalLM(Model):
|
class FlashCausalLM(Model):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
model_id: str,
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
num_layers: int,
|
num_layers: int,
|
||||||
|
@ -738,6 +803,7 @@ class FlashCausalLM(Model):
|
||||||
self.kv_cache = []
|
self.kv_cache = []
|
||||||
|
|
||||||
super(FlashCausalLM, self).__init__(
|
super(FlashCausalLM, self).__init__(
|
||||||
|
model_id=model_id,
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
requires_padding=False,
|
requires_padding=False,
|
||||||
|
@ -996,7 +1062,7 @@ class FlashCausalLM(Model):
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, batch: FlashCausalLMBatch
|
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
# Model Forward
|
# Model Forward
|
||||||
if batch.speculative_ids is not None:
|
if batch.speculative_ids is not None:
|
||||||
|
@ -1066,13 +1132,6 @@ class FlashCausalLM(Model):
|
||||||
batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
|
batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
|
||||||
lora_indices = torch.full((bs,), -1, dtype=torch.int32, device=self.device)
|
lora_indices = torch.full((bs,), -1, dtype=torch.int32, device=self.device)
|
||||||
|
|
||||||
for i, r in enumerate(batch.requests):
|
|
||||||
if r.adapter_id:
|
|
||||||
lora_index = self.model.get_lora_index(r.adapter_id)
|
|
||||||
input_length = batch.input_lengths[i]
|
|
||||||
lora_indices[i : i + input_length] = lora_index
|
|
||||||
batch_lora_adapter_mask[i] = True
|
|
||||||
|
|
||||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||||
logits, speculative_logits = self.model.forward(
|
logits, speculative_logits = self.model.forward(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
|
@ -1087,6 +1146,7 @@ class FlashCausalLM(Model):
|
||||||
lm_head_indices=lm_head_indices,
|
lm_head_indices=lm_head_indices,
|
||||||
batch_lora_adapter_mask=batch_lora_adapter_mask,
|
batch_lora_adapter_mask=batch_lora_adapter_mask,
|
||||||
lora_indices=lora_indices,
|
lora_indices=lora_indices,
|
||||||
|
adapter_data=adapter_data,
|
||||||
)
|
)
|
||||||
if batch.prefill_cache_indices is not None:
|
if batch.prefill_cache_indices is not None:
|
||||||
batch.prefill_cache_indices = None
|
batch.prefill_cache_indices = None
|
||||||
|
@ -1123,7 +1183,34 @@ class FlashCausalLM(Model):
|
||||||
prefill = batch.cu_seqlen_prefill is not None
|
prefill = batch.cu_seqlen_prefill is not None
|
||||||
prefill_logprobs = batch.prefill_next_token_indices is not None
|
prefill_logprobs = batch.prefill_next_token_indices is not None
|
||||||
|
|
||||||
out, speculative_logits = self.forward(batch)
|
# Update adapter indices for speculative tokens (if present)
|
||||||
|
adapter_meta = batch.adapter_meta
|
||||||
|
if batch.speculative_ids is not None:
|
||||||
|
B, speculative_length = batch.speculative_ids.shape
|
||||||
|
new_length = speculative_length + 1
|
||||||
|
adapter_indices = (
|
||||||
|
adapter_meta.adapter_indices.unsqueeze(-1)
|
||||||
|
.expand(B, new_length)
|
||||||
|
.reshape(-1)
|
||||||
|
)
|
||||||
|
adapter_segments = adapter_meta.adapter_segments * new_length
|
||||||
|
adapter_meta = AdapterBatchMetadata(
|
||||||
|
adapter_indices=adapter_indices,
|
||||||
|
adapter_set=adapter_meta.adapter_set,
|
||||||
|
adapter_segments=adapter_segments,
|
||||||
|
segment_indices=adapter_meta.segment_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign pointers to adapter weights
|
||||||
|
# TODO(travis): don't update this if indices haven't changed
|
||||||
|
adapter_data = AdapterBatchData.from_meta(
|
||||||
|
adapter_meta,
|
||||||
|
self.layer_to_adapter_weights,
|
||||||
|
prefill,
|
||||||
|
batch.prefill_head_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
out, speculative_logits = self.forward(batch, adapter_data)
|
||||||
|
|
||||||
if prefill:
|
if prefill:
|
||||||
next_token_logits = (
|
next_token_logits = (
|
||||||
|
|
|
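Illustrative sketch (not part of the commit): the unsqueeze/expand/reshape above repeats each request's adapter index across its speculative token positions, so the per-token adapter mapping stays aligned with the flattened input. The values below are assumed for the example.

import torch

# Assumed: 3 requests mapped to adapters [0, 0, 2], one speculative id each,
# so new_length = speculative_length + 1 = 2.
adapter_indices = torch.tensor([0, 0, 2])
new_length = 2
expanded = adapter_indices.unsqueeze(-1).expand(3, new_length).reshape(-1)
# expanded == tensor([0, 0, 0, 0, 2, 2]): each request's adapter index is
# duplicated once per generated position.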
@@ -4,7 +4,7 @@ import torch.distributed

 from opentelemetry import trace
 from transformers import AutoConfig, AutoTokenizer, GenerationConfig
-from typing import Optional
+from typing import Optional, Tuple, Dict, List

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
@@ -22,6 +22,30 @@ tracer = trace.get_tracer(__name__)
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.lora import LoraConfig

+Q_PROJ = "q_proj"
+K_PROJ = "k_proj"
+V_PROJ = "v_proj"
+O_PROJ = "o_proj"
+
+GATE_PROJ = "gate_proj"
+UP_PROJ = "up_proj"
+DOWN_PROJ = "down_proj"
+
+LM_HEAD = "lm_head"
+
+
+# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
+ADAPTER_LAYERS = [
+    Q_PROJ,
+    K_PROJ,
+    V_PROJ,
+    O_PROJ,
+    GATE_PROJ,
+    UP_PROJ,
+    DOWN_PROJ,
+]  # LM_HEAD
+ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}
+

class FlashLlama(FlashCausalLM):
    def __init__(
@@ -80,6 +104,7 @@ class FlashLlama(FlashCausalLM):
        )
        torch.distributed.barrier(group=self.process_group)
        super(FlashLlama, self).__init__(
+            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
@@ -90,3 +115,59 @@ class FlashLlama(FlashCausalLM):
            rank=rank,
            world_size=world_size,
        )

+    @property
+    def supports_adapter_loading(self) -> bool:
+        return True
+
+    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
+        layer_weights = {}
+
+        prefix = "model.layers"
+        for i, layer in enumerate(self.model.model.layers):
+            layer_weights[(i, Q_PROJ)] = (
+                f"{prefix}.{i}.self_attn.q_proj",
+                layer.self_attn.query_key_value,
+            )
+            layer_weights[(i, K_PROJ)] = (
+                f"{prefix}.{i}.self_attn.k_proj",
+                layer.self_attn.query_key_value,
+            )
+            layer_weights[(i, V_PROJ)] = (
+                f"{prefix}.{i}.self_attn.v_proj",
+                layer.self_attn.query_key_value,
+            )
+            layer_weights[(i, O_PROJ)] = (
+                f"{prefix}.{i}.self_attn.o_proj",
+                layer.self_attn.o_proj,
+            )
+
+            layer_weights[(i, GATE_PROJ)] = (
+                f"{prefix}.{i}.mlp.gate_proj",
+                layer.mlp.gate_up_proj,
+            )
+            layer_weights[(i, UP_PROJ)] = (
+                f"{prefix}.{i}.mlp.up_proj",
+                layer.mlp.gate_up_proj,
+            )
+            layer_weights[(i, DOWN_PROJ)] = (
+                f"{prefix}.{i}.mlp.down_proj",
+                layer.mlp.down_proj,
+            )
+
+        layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
+        return layer_weights
+
+    @property
+    def adapter_layers(self) -> List[str]:
+        return ADAPTER_LAYERS
+
+    @property
+    def default_traced_adapter_layers(self) -> List[str]:
+        return [Q_PROJ, V_PROJ]
+
+    def get_num_layers_for_type(self, layer_type: str) -> int:
+        return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
+
+    def is_row_parallel(self, layer_type: str) -> bool:
+        return layer_type in ROW_PARALLEL
@@ -2,12 +2,50 @@ import inspect
 import torch

 from abc import ABC, abstractmethod
-from typing import List, Tuple, Optional, TypeVar, Type
+from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
+from collections import defaultdict
 from transformers import PreTrainedTokenizerBase, PretrainedConfig

 from text_generation_server.models.types import Batch, Generation
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.pb.generate_pb2 import InfoResponse
+from text_generation_server.adapters.weights import LayerAdapterWeights
+from text_generation_server.utils.adapter import (
+    load_and_merge_adapters,
+    AdapterParameters,
+    AdapterSource,
+)
+from loguru import logger
+
+
+BASE_MODEL_ADAPTER_ID = "__base_model__"
+
+
+def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
+    block_size = size // world_size
+    start = offset + rank * block_size
+    stop = offset + (rank + 1) * block_size
+    return start, stop
+
+
+def shard_on_dim(
+    t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
+):
+    world_size = process_group.size()
+    rank = process_group.rank()
+
+    size = t.shape[dim]
+    start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
+
+    if dim == 0:
+        tensor = t[start:stop]
+    elif dim == 1:
+        tensor = t[:, start:stop]
+    else:
+        raise NotImplementedError("Let's make that generic when needed")
+
+    return tensor
+

 B = TypeVar("B", bound=Batch)

@@ -15,6 +53,7 @@ B = TypeVar("B", bound=Batch)
class Model(ABC):
    def __init__(
        self,
+        model_id: str,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        requires_padding: bool,
@@ -25,6 +64,7 @@ class Model(ABC):
        sliding_window: Optional[int] = None,
        speculate: Optional[int] = None,
    ):
+        self.model_id = model_id
        self.model = model.eval()
        self.tokenizer = tokenizer

@@ -42,6 +82,12 @@ class Model(ABC):
        self.world_size = world_size
        self.sliding_window = sliding_window if sliding_window != -1 else None

+        self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
+            LayerAdapterWeights
+        )
+        self.target_to_layer = self.adapter_target_to_layer()
+        self.loaded_adapters = set()
+
        if speculate is None:
            speculate = get_speculate()
        self.speculate = speculate
@@ -119,3 +165,156 @@ class Model(ABC):
            raise RuntimeError(
                f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
            )

+    @property
+    def supports_adapter_loading(self) -> bool:
+        return False
+
+    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
+        return {}
+
+    @property
+    def adapter_layers(self) -> List[str]:
+        return []
+
+    @property
+    def default_traced_adapter_layers(self) -> List[str]:
+        return []
+
+    def get_num_layers_for_type(self, layer_type: str) -> int:
+        return 0
+
+    def is_row_parallel(self, layer_type: str) -> bool:
+        return False
+
+    @property
+    def max_speculative_tokens(self) -> int:
+        return max(
+            [
+                weights.max_speculative_tokens
+                for weights in self.layer_to_adapter_weights.values()
+            ],
+            default=0,
+        )
+
+    def load_adapter(
+        self,
+        adapter_parameters: AdapterParameters,
+        adapter_source: AdapterSource,
+        adapter_index: int,
+        api_token: str,
+        dynamic: bool = True,
+    ):
+        """Loads adapter weights from disk / host memory on the GPU.
+
+        adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded
+        into model. Otherwise, the adapter weights are applied during the forward
+        pass and stored separately from the base model parameters.
+        """
+        if adapter_index in self.loaded_adapters:
+            # Adapter already loaded
+            return
+
+        if not self.supports_adapter_loading:
+            raise ValueError("This model does not support adapter loading.")
+
+        if dynamic and not self.dynamic_adapter_loading_enabled:
+            raise ValueError(
+                f"This model was initialized with the adapter {self.static_adapter_id} "
+                f"and therefore does not support dynamic adapter loading. "
+                f"Please initialize a new model instance from the base model in "
+                f"order to use the dynamic adapter loading feature."
+            )
+
+        logger.info(
+            f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
+        )
+        weight_names = tuple([v[0] for v in self.target_to_layer.values()])
+        (
+            module_map,
+            adapter_config,
+            adapter_weight_names,
+            adapter_tokenizer,
+        ) = load_and_merge_adapters(
+            self.model_id,
+            adapter_parameters,
+            adapter_source,
+            adapter_index,
+            weight_names,
+            api_token,
+            False,
+        )
+
+        unused_weight_names = adapter_weight_names.copy()
+        for layer_name in self.adapter_layers:
+            adapter_weights = adapter_config.load_batched_adapter_weights(
+                self,
+                module_map,
+                layer_name,
+                unused_weight_names,
+                dynamic,
+            )
+
+            if adapter_weights is None:
+                continue
+
+            layer_weights = self.layer_to_adapter_weights[layer_name]
+            layer_weights.add_adapter(adapter_index, adapter_weights)
+
+        if len(unused_weight_names) > 0:
+            logger.warning(
+                f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
+            )
+
+        if adapter_tokenizer is not None:
+            self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
+
+        self.loaded_adapters.add(adapter_index)
+
+    def shard_lora_weights(
+        self,
+        weights_a: List[torch.Tensor],
+        weights_b: List[torch.Tensor],
+        layer_type: str,
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+        # [hidden_size, r]
+        split_dim = 0 if self.is_row_parallel(layer_type) else 1
+        weights_a = [
+            shard_on_dim(w, dim=split_dim, process_group=self.process_group)
+            for w in weights_a
+        ]
+
+        # [r, hidden_size]
+        weights_b = [
+            shard_on_dim(w, dim=1, process_group=self.process_group) for w in weights_b
+        ]
+
+        return weights_a, weights_b
+
+    def offload_adapter(
+        self,
+        adapter_parameters: AdapterParameters,
+        adapter_source: AdapterSource,
+        adapter_index: int,
+    ):
+        """Offloads the adapter weights from GPU to CPU or disk."""
+        if adapter_index not in self.loaded_adapters:
+            # Adapter already offloaded
+            return
+
+        if not self.supports_adapter_loading:
+            raise ValueError("This model does not support adapter loading.")
+
+        if not self.dynamic_adapter_loading_enabled:
+            raise ValueError(
+                f"This model was initialized with the adapter {self.static_adapter_id} "
+                f"and therefore does not support dynamic adapter loading. "
+                f"Please initialize a new model instance from the base model in "
+                f"order to use the dynamic adapter loading feature."
+            )
+
+        for layer_name in self.adapter_layers:
+            if layer_name in self.layer_to_adapter_weights:
+                self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index)
+
+        self.loaded_adapters.remove(adapter_index)
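A minimal sketch (not part of the diff) of the arithmetic that get_start_stop_idxs_for_rank performs and that shard_on_dim then uses to slice each LoRA matrix per tensor-parallel rank; the sizes here are assumed for illustration.

# Assumed: a dimension of size 64 split across a world size of 4.
offset, size, rank, world_size = 0, 64, 1, 4
block_size = size // world_size
start, stop = offset + rank * block_size, offset + (rank + 1) * block_size
assert (start, stop) == (16, 32)  # rank 1 keeps t[16:32] (dim 0) or t[:, 16:32] (dim 1)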
@@ -30,6 +30,9 @@ except (ImportError, NotImplementedError):
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.globals import set_model_id
+from text_generation_server.utils.adapter import (
+    AdapterParameters,
+)


class SignalHandler:
@@ -235,6 +238,30 @@ def serve(
                trust_remote_code,
                max_input_tokens,
            )

+            # TODO: avoid hacky hardcoded adapter id
+            adapter_parameters = AdapterParameters(
+                adapter_ids=lora_adapter_ids,
+                weights=[
+                    # TODO: fill with actual weights
+                    torch.tensor([1.0], dtype=torch.float32)
+                ],
+                merge_strategy=0,
+                density=0.0,
+                majority_sign_method=0,
+            )
+            adapter_source = None
+            adapter_index = None
+            api_token = None
+
+            model.load_adapter(
+                adapter_parameters,
+                adapter_source,
+                adapter_index,
+                api_token,
+                False,
+            )
+
        except Exception:
            logger.exception("Error when initializing model")
            raise
@@ -0,0 +1,196 @@
import warnings
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Set, Tuple

from safetensors.torch import load_file
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

from text_generation_server.pb import generate_pb2
from text_generation_server.utils.merges.strategies import merge_adapters

from text_generation_server.utils import hub
from text_generation_server.adapters.lora import LoraConfig


if TYPE_CHECKING:
    from text_generation_server.adapters.config import AdapterConfig, ModuleMap


BASE_MODEL_ADAPTER_ID = "__base_model__"


class AdapterParameters:
    def __init__(
        self, adapter_ids, weights, merge_strategy, density, majority_sign_method
    ):
        self.adapter_ids = adapter_ids
        self.weights = weights
        self.merge_strategy = merge_strategy
        self.density = density
        self.majority_sign_method = majority_sign_method


class AdapterSource:
    def __init__(self, adapter_id: str, model_id: str, revision: str):
        self.adapter_id = adapter_id
        self.model_id = model_id
        self.revision = revision


def load_and_merge_adapters(
    model_id: str,
    adapter_parameters: AdapterParameters,
    adapter_source: str,
    adapter_index: int,
    weight_names: Tuple[str],
    api_token: str,
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
    if len(adapter_parameters.adapter_ids) == 1:
        return load_module_map(
            model_id,
            adapter_parameters.adapter_ids[0],
            adapter_source,
            weight_names,
            api_token,
            trust_remote_code,
        )

    adapter_params = AdapterParametersContainer(
        adapter_parameters, adapter_source, adapter_index
    )
    return _load_and_merge(
        model_id, adapter_params, weight_names, api_token, trust_remote_code
    )


class AdapterParametersContainer:
    def __init__(self, adapter_parameters, adapter_source, adapter_index):
        self.adapter_parameters = adapter_parameters
        self.adapter_source = adapter_source
        self.adapter_index = adapter_index

    def __hash__(self) -> int:
        return self.adapter_index


@lru_cache(maxsize=32)
def _load_and_merge(
    model_id: str,
    adapter_params: AdapterParametersContainer,
    weight_names: Tuple[str],
    api_token: str,
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
    params = adapter_params.adapter_parameters

    adapters_to_merge = []
    merged_weight_names = set()
    tokenizer = None
    for adapter_id in params.adapter_ids:
        if adapter_id == BASE_MODEL_ADAPTER_ID:
            raise ValueError("Base model adapter cannot be merged.")

        module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
            load_module_map(
                model_id,
                adapter_id,
                adapter_params.adapter_source,
                weight_names,
                api_token,
                trust_remote_code,
            )
        )

        adapters_to_merge.append((module_map, adapter_config))
        merged_weight_names = merged_weight_names.union(adapter_weight_names)
        if tokenizer is None:
            tokenizer = adapter_tokenizer

    if len(adapters_to_merge) == 0:
        raise ValueError("No adapters to merge.")

    module_map, adapter_config = merge_adapters(adapters_to_merge, params)
    return module_map, adapter_config, merged_weight_names, tokenizer


def check_architectures(
    model_id: str,
    adapter_id: str,
    adapter_config: "AdapterConfig",
    trust_remote_code: bool = False,
):
    try:
        if not adapter_config.base_model_name_or_path:
            # Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None)
            return

        expected_config = AutoConfig.from_pretrained(
            model_id, trust_remote_code=trust_remote_code
        )
        model_config = AutoConfig.from_pretrained(
            adapter_config.base_model_name_or_path, trust_remote_code=trust_remote_code
        )
    except Exception as e:
        warnings.warn(
            f"Unable to check architecture compatibility for adapter '{adapter_id}' "
            f"against model '{model_id}'. Assuming they are compatible. Error: {e}"
        )
        return

    if model_config.architectures == expected_config.architectures:
        warnings.warn(
            f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. "
            f"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead."
        )
    else:
        # TODO(travis): revisit this when we support classification heads which will not use CausalLM
        raise ValueError(
            f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
            f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. "
            f"Use --model-id '{adapter_config.base_model_name_or_path}' instead."
        )


@lru_cache(maxsize=128)
def load_module_map(
    model_id: str,
    adapter_id: str,
    adapter_source: str,
    weight_names: Tuple[str],
    api_token: str,
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
    print("adapter_id", adapter_id)

    revision = "main"

    adapter_config = LoraConfig.load(adapter_id, api_token)
    if adapter_config.base_model_name_or_path != model_id:
        check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)

    adapter_filenames = hub._cached_adapter_weight_files(
        adapter_id, revision=revision, extension=".safetensors"
    )

    try:
        adapter_tokenizer = AutoTokenizer.from_pretrained(
            adapter_config.config_path,
            token=api_token,
            trust_remote_code=trust_remote_code,
        )
    except Exception:
        # Adapter does not have a tokenizer, so fallback to base model tokenizer
        adapter_tokenizer = None

    # load adapter weights from all shards (should have relatively small memory footprint)
    adapter_weights = {}
    for filename in adapter_filenames:
        adapter_weights.update(load_file(filename))

    # map the model weights to the relevant adapter weights (LoRA A and B matrices)
    module_map, adapter_weight_names = adapter_config.map_weights_for_model(
        adapter_weights, weight_names
    )
    return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
@@ -32,6 +32,7 @@ class LoraConfig:
        task_type="CAUSAL_LM",
        use_dora=False,
        use_rslora=False,
+        config_path=None,
    ):
        self.alpha_pattern = alpha_pattern or {}
        self.auto_mapping = auto_mapping
@@ -57,12 +58,13 @@ class LoraConfig:
        self.task_type = task_type
        self.use_dora = use_dora
        self.use_rslora = use_rslora
+        self.config_path = config_path

    @classmethod
    def from_file(cls, filename):
        with open(filename, "r") as f:
            json_data = json.load(f)
-        return cls(**json_data)
+        return cls(**json_data, config_path=filename)

    # TODO: support fetching the model from the hub if it's not in the cache
    @classmethod
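Illustrative note (not part of the diff): with this change, from_file records where the config was read from, which load_module_map later uses as the path for the adapter tokenizer. The path below is hypothetical.

cfg = LoraConfig.from_file("/data/adapters/my-adapter/adapter_config.json")
assert cfg.config_path == "/data/adapters/my-adapter/adapter_config.json"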
@@ -0,0 +1,223 @@
import copy
from abc import ABC
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union

import torch


class AdapterParameters:
    def __init__(
        self, adapter_ids, weights, merge_strategy, density, majority_sign_method
    ):
        self.adapter_ids = adapter_ids
        self.weights = weights
        self.merge_strategy = merge_strategy
        self.density = density
        self.majority_sign_method = majority_sign_method


from text_generation_server.utils.merges.utils import (
    calculate_majority_sign_mask,
    disjoint_merge,
    prune,
)

if TYPE_CHECKING:
    from text_generation_server.adapters.lora import LoraConfig
    from text_generation_server.utils.adapter import ModuleMap


def _apply_weights(
    tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor
) -> torch.Tensor:
    if isinstance(tensors, torch.Tensor):
        t = tensors
    else:
        t = torch.stack(tensors, dim=0)

    # element-wise weighting of each task tensor
    # need to unsqueeze weights to match task tensor dimensions
    # for multiplication to apply element-wise
    while len(t.shape) > len(w.shape):
        w = w.unsqueeze(-1)
    return t * w


class MergeStrategy(ABC):
    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        raise NotImplementedError()


class LinearMerge(MergeStrategy):
    def __init__(self, **kwargs):
        pass

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        weighted_task_tensors = _apply_weights(task_tensors, weights)
        return weighted_task_tensors.sum(dim=0)


class TiesMerge(MergeStrategy):
    def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
        self.density = density
        self.majority_sign_method = majority_sign_method

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        # sparsify
        task_tensors = [
            prune(tensor, self.density, method="magnitude") for tensor in task_tensors
        ]
        task_tensors = torch.stack(task_tensors, dim=0)

        # elect sign before applying weights
        majority_sign_mask = calculate_majority_sign_mask(
            task_tensors, method=self.majority_sign_method
        )
        weighted_task_tensors = _apply_weights(task_tensors, weights)

        # disjoint merge
        return disjoint_merge(weighted_task_tensors, majority_sign_mask)


class DareLinearMerge(MergeStrategy):
    def __init__(self, density: float, **kwargs):
        self.density = density

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        # sparsify
        task_tensors = [
            prune(tensor, self.density, method="random", rescale=True)
            for tensor in task_tensors
        ]
        weighted_task_tensors = _apply_weights(task_tensors, weights)
        return weighted_task_tensors.sum(dim=0)


class DareTiesMerge(MergeStrategy):
    def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
        self.density = density
        self.majority_sign_method = majority_sign_method

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        # sparsify
        task_tensors = [
            prune(tensor, self.density, method="random", rescale=True)
            for tensor in task_tensors
        ]
        task_tensors = torch.stack(task_tensors, dim=0)

        # elect sign before applying weights
        majority_sign_mask = calculate_majority_sign_mask(
            task_tensors, method=self.majority_sign_method
        )
        weighted_task_tensors = _apply_weights(task_tensors, weights)

        # disjoint merge
        mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask)
        return mixed_task_tensors


strategy_registry: Dict[str, Type[MergeStrategy]] = {
    "linear": LinearMerge,
    "ties": TiesMerge,
    "dare_linear": DareLinearMerge,
    "dare_ties": DareTiesMerge,
}


def merge_adapters(
    adapters: List[Tuple["ModuleMap", "LoraConfig"]],
    merge_params: AdapterParameters,
) -> Tuple["ModuleMap", "LoraConfig"]:
    # strategy_name = MergeStrategyEnum.Name(merge_params.merge_strategy).lower()
    strategy_name = "linear"

    weights = merge_params.weights
    if not weights:
        weights = torch.ones(len(adapters))
    else:
        weights = torch.tensor(weights)

    merge_config = {
        "density": merge_params.density,
        # "majority_sign_method": MajoritySignMethodEnum.Name(
        #     merge_params.majority_sign_method
        # ).lower(),
        "majority_sign_method": "total",
    }
    merge_strategy = strategy_registry[strategy_name](**merge_config)

    module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict(
        lambda: defaultdict(lambda: defaultdict(list))
    )
    lora_configs = []
    weight_name_to_adapter_idx = defaultdict(list)

    # input is list of (module_map, lora_config) tuples
    # convert into dict[k][param_name] -> list of tensors
    for idx, (module_map, lora_config) in enumerate(adapters):
        for weight_name, data in module_map.items():
            weight_name_to_adapter_idx[weight_name].append(idx)
            for k, (param_data, param_name) in data.items():
                module_maps[weight_name][k][param_name].append(param_data)
        lora_configs.append(lora_config)

    # validate lora configs are compatible
    _validate_lora_configs(lora_configs)

    # merge tensors for each module such that we have a single ModuleMap:
    # dict[k] -> merged tensor
    merged_module_map: "ModuleMap" = defaultdict(dict)
    for weight_name, data in module_maps.items():
        indices = weight_name_to_adapter_idx[weight_name]
        param_weights = weights[indices]
        for k, param_data in data.items():
            for param_name, tensors in param_data.items():
                merged_tensor = merge_strategy.merge(tensors, param_weights)
                merged_module_map[weight_name][k] = (merged_tensor, param_name)

    # merge lora configs
    merged_lora_config = _merge_lora_configs(lora_configs)

    return merged_module_map, merged_lora_config


def _validate_lora_configs(lora_configs: List["LoraConfig"]):
    # check that all configs have the same rank
    ranks = set(lora_config.r for lora_config in lora_configs)
    if len(ranks) > 1:
        raise ValueError(
            f"unable to merge adapters, lora configs have different ranks: {ranks}"
        )

    if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs):
        raise ValueError(
            "unable to merge adapters, lora configs have no target modules"
        )


def _merge_lora_configs(lora_configs: List["LoraConfig"]) -> "LoraConfig":
    merged_lora_config = copy.copy(lora_configs[0])

    # merge target modules as a union operation
    merged_target_modules = sorted(
        set(
            module
            for lora_config in lora_configs
            for module in lora_config.target_modules
        )
    )
    merged_lora_config.target_modules = merged_target_modules

    return merged_lora_config
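A hedged example (not part of the diff) of the "linear" strategy above: two task tensors merged with per-adapter weights; the values are made up for illustration.

import torch

task_tensors = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
weights = torch.tensor([0.75, 0.25])
merged = LinearMerge().merge(task_tensors, weights)
# merged == tensor([1.5, 2.5]), i.e. 0.75 * [1, 2] + 0.25 * [3, 4]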
@@ -0,0 +1,108 @@
# coding=utf-8
# From: https://github.com/huggingface/peft/pull/1364
# Copyright 2024-present the HuggingFace Inc. team.
# Modifications by Predibase, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Literal

import torch


def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor:
    """
    Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction
    `density`.

    Args:
        tensor (`torch.Tensor`): The tensor to prune.
        density (`float`): The fraction of values to preserve. Should be in [0,1].
    """
    mask = torch.zeros_like(tensor).reshape(-1)
    k = int(density * tensor.reshape(-1).shape[0])
    top_k = torch.topk(tensor.abs().reshape(-1), k=k, largest=True)
    mask[top_k[1]] = 1
    return tensor * mask.reshape(tensor.shape)


def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:
    """
    Prune random values of the task tensors based on the specified fraction `density`.

    Args:
        tensor (`torch.Tensor`): The tensor to prune.
        density (`float`): The fraction of values to preserve. Should be in [0,1].
        rescale (`bool`): Whether to rescale the result to preserve the expected value of the original tensor.
    """
    mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density))
    pruned_tensor = tensor * mask
    if rescale:
        torch.div(input=pruned_tensor, other=density)
    return pruned_tensor


def prune(
    tensor: torch.Tensor,
    density: float,
    method: Literal["magnitude", "random"],
    rescale: bool = False,
) -> torch.Tensor:
    """
    Prune the values of task tensors based on the `method`.

    Args:
        tensor (`torch.Tensor`): The tensor to prune.
        density (`float`): The fraction of values to preserve. Should be in [0,1].
        method (`str`): The method to use to prune. Should be one of ["magnitude", "random"].
        rescale (`bool`): Whether to rescale the result to preserve the expected value of the original tensor.
    """
    if density >= 1:
        return tensor
    elif density < 0:
        raise ValueError(f"Density should be >= 0, got {density}")
    if method == "magnitude":
        return magnitude_based_pruning(tensor, density)
    elif method == "random":
        return random_pruning(tensor, density, rescale=rescale)
    else:
        raise ValueError(f"Unknown method {method}")


def calculate_majority_sign_mask(
    tensor: torch.Tensor, method: Literal["total", "frequency"] = "total"
):
    """
    Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0.

    Args:
        tensor (`torch.Tensor`): The tensor to get the mask from.
        method (`str`): The method to use to get the mask. Should be one of ["total", "frequency"].
    """

    sign = tensor.sign()
    if method == "total":
        sign_magnitude = (sign * tensor.abs()).sum(dim=0)
    elif method == "frequency":
        sign_magnitude = sign.sum(dim=0)
    else:
        raise RuntimeError(f'Unimplemented mask method "{method}"')
    majority_sign = torch.where(sign_magnitude >= 0, 1, -1)
    return sign == majority_sign


def disjoint_merge(task_tensors, majority_sign_mask):
    mixed_task_tensors = (task_tensors * majority_sign_mask).sum(dim=0)
    num_params_preserved = majority_sign_mask.sum(dim=0)
    return mixed_task_tensors / torch.clamp(num_params_preserved, min=1.0)
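A hedged example (not part of the diff) of magnitude_based_pruning above with density 0.5: the two largest-magnitude entries of a 4-element tensor are kept and the rest are zeroed. The values are assumed for illustration.

import torch

t = torch.tensor([0.1, -4.0, 0.2, 3.0])
pruned = magnitude_based_pruning(t, density=0.5)
# pruned == tensor([0.0, -4.0, 0.0, 3.0])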
@@ -0,0 +1,62 @@
from typing import List, Tuple, Union

import torch


def find_segments(
    adapter_indices: Union[torch.Tensor, List[int]]
) -> Tuple[List[int], List[int]]:
    segments = [0]
    segment_indices = []

    if isinstance(adapter_indices, torch.Tensor):
        # Calling .item() repeatedly on CUDA tensor is very slow, so we move it to CPU first
        adapter_indices = adapter_indices.cpu().tolist()

    start_index = 0
    for i in range(1, len(adapter_indices)):
        if adapter_indices[i] != adapter_indices[i - 1]:
            segments.append(i)
            segment_indices.append(adapter_indices[i - 1])
            start_index = i

    # Handle the last segment
    if start_index < len(adapter_indices):
        segments.append(len(adapter_indices))
        segment_indices.append(adapter_indices[-1])

    return segments, segment_indices


class SegmentConcatBuilder:
    def __init__(self):
        self.adapter_segment_indices = []
        self.adapter_segment_tensors = []

    def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]):
        # Update adapter segments
        if self.adapter_segment_tensors:
            # Because we have already processed at least one batch, remove the 0 start index
            # from this batch denoting the beginning of the segment, then offset all segment
            # positions by the value of the last segment in the previous batch to account for
            # the concatenation.
            adapter_segments = (
                adapter_segments[1:] + self.adapter_segment_tensors[-1][-1]
            )

        if (
            self.adapter_segment_indices
            and self.adapter_segment_indices[-1] == segment_indices[0]
        ):
            # If the last segment in the previous batch is the same as the first segment in this batch,
            # then we merge them together into a single segment. In effect, this means removing it from
            # the segment indices of this batch, and extending the segment span by removing the segment
            # end index from the previous batch.
            segment_indices = segment_indices[1:]
            self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1]

        self.adapter_segment_indices.extend(segment_indices)
        self.adapter_segment_tensors.append(adapter_segments)

    def build(self) -> Tuple[torch.Tensor, List[int]]:
        return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices
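A hedged example (not part of the diff) of find_segments above: consecutive equal adapter indices are grouped into run boundaries plus the adapter index of each run. The input values are assumed for illustration.

import torch

segments, segment_indices = find_segments(torch.tensor([0, 0, 1, 1, 1, 2]))
# segments == [0, 2, 5, 6]   (boundaries of the three runs)
# segment_indices == [0, 1, 2]   (adapter index of each run)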
@@ -0,0 +1,242 @@
import os
import warnings
from functools import lru_cache
from typing import List, Tuple

import torch
import torch.nn.functional as F

try:
    # TODO: add build steps for Punica kernels
    # import punica_kernels as _kernels
    import punica.ops as _kernels

    HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
except ImportError:
    warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.")
    _kernels = None
    HAS_SGMV = False


MIN_SGMV_RANK = 8
MIN_RANK_CUSTOM = 16
MAX_RANK_CUSTOM = 128
SGMV_BLOCK_SIZE = 16
BGMV_MAX_RANK = 64


def has_sgmv() -> bool:
    return HAS_SGMV


def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
    """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size."""
    if not has_sgmv():
        return t

    # tensor parallelism will result in effective rank being divided by world_size,
    # so we need to scale the min rank to offset that effect
    min_rank = MIN_SGMV_RANK * world_size

    # if we're at or below the min rank, pad up to the min rank
    # otherwise, pad to the nearest multiple of the block size
    current_rank = t.size(dim)
    target_rank = (
        min_rank
        if current_rank <= min_rank
        else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
    )
    if current_rank == target_rank:
        return t

    pad_size = target_rank - current_rank

    # see complicated pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
    pad = [0, 0] * t.dim()
    pad[(t.dim() - dim - 1) * 2 + 1] = pad_size
    pad = tuple(pad)

    return F.pad(t, pad, mode="constant", value=0.0)


def use_cutlass_shrink(lora_rank: int) -> bool:
    return lora_rank < MIN_RANK_CUSTOM


def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
    if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
        return t.transpose(0, 1)
    return t


# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py
def add_lora_sgmv_cutlass(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_ptr: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.Tensor,
    s_end: torch.Tensor,
    layer_idx: int,
    lora_rank: int,
):
    """
    Semantics:
        y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])

    Args:
        y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
        x: Shape: `[B, H1]`. Input vectors.
        wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
            Weight matrix shape: `[num_layers, R, H1]`.
        wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
            Weight matrix shape: `[num_layers, R, H2]`.
        s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices.
        s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices.
        layer_idx: Layer index of the weight matrices.
    """
    if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
        # Custom SGMV shrink only supports rank 16, 32, 64, 128
        _add_lora_sgmv_cutlass_legacy(
            y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank
        )
        return

    tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
    tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
    tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)


def _add_lora_sgmv_cutlass_legacy(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_ptr: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
    lora_rank: int,
):
    tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
    tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)


@lru_cache(maxsize=1)
def get_tmp_tensor(device: torch.device) -> torch.Tensor:
    return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)


@lru_cache(maxsize=32)
def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
    tmp_size = _kernels.sgmv_cutlass_tmp_size(size)
    return torch.empty((tmp_size,), dtype=torch.uint8, device=device)


def get_tmp_expand_size(size: int) -> int:
    return _kernels.sgmv_cutlass_tmp_size(size)


def get_tmp_tensors(
    nsegments: int, lora_rank: int, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
    if use_cutlass_shrink(lora_rank):
        tmp = get_tmp_tensor_for_size(nsegments, device)
        return tmp, tmp
    else:
        tmp_shrink = get_tmp_tensor(device)
        tmp_expand = get_tmp_tensor_for_size(nsegments, device)
        return tmp_shrink, tmp_expand


def lora_a_sgmv_cutlass(
    x: torch.Tensor,
    tmp: torch.Tensor,
    wa_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
    lora_rank: int,
) -> torch.Tensor:
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
        _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    else:
        _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    return v


def lora_b_sgmv_cutlass(
    y: torch.Tensor,
    v: torch.Tensor,
    tmp: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
):
    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)


"""
Semantics:
    y[i] += (
        x[i].unsqueeze(0)
        @ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
        @ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
        * scale
    ).squeeze(0)

Args:
    y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
    v: Shape: `[B, R]`. Temporary vector.
    x: Shape: `[B, H1]`. Input vectors.
    wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices.
    wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices.
    indicies: Shape: `[B]`. Indices of the LoRA weights.
    layer_idx: Layer index of LoRA weights.
    scale: Scaling factor.
"""


def add_lora_a_bgmv(
    v: torch.Tensor,
    x: torch.Tensor,
    wa_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
):
    _kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)


def add_lora_b_bgmv(
    y: torch.Tensor,
    v: torch.Tensor,
    wb_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
):
    _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)


def segmented_matmul(
    y: torch.Tensor,
    x: torch.Tensor,
    w: List[torch.Tensor],
    b: List[torch.Tensor],
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
):
    for i in range(len(w)):
        if s_end[i] - s_start[i] <= 0:
            continue

        xi = x[s_start[i] : s_end[i]]
        wi = w[i]
        bi = b[i]
        y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi)