Enable multiple LoRa adapters (#2010)

* feat: first draft load multiple lora

* feat: load weights within layer and refactor lora pass

* fix: refactor and reduce lora math

* feat: baseline impl single request multi lora support

* feat: prefer lorax implementation and port loading logic

* fix: prefer adapter_data and refactors

* feat: perfer loraxs custom punica kernels and add mlp loras

* fix: adjust batch for bgmv

* fix: adjust adapter_segments logic when in batch

* fix: refactor and move changes to v3 proto

* fix: pass model_id for all flash causal lms

* fix: pass model_id for all causal and seq2seq lms

* fix: add model_id to model test

* feat: add lora support to mistral and refactors

* feat: prefer model id in request

* fix: include rust code for adapter id

* feat: bump launcher and add new lora docs

* feat: support base model generation and refactors

* fix: rename doc to retry ci build

* feat: support if vlm models

* fix: add adapter_data param and avoid missing layers

* fix: add adapter_data param to phi and neox

* fix: update all models forwards to include adapter_data

* fix: add model_id to IdeficsCausalLM

* Update lora.md

Fixed a typo

* Update lora.md

Fixing spam image

* fix: add lora kernel to dockerfile, support running without kernels and refactors

* fix: avoid dockerfile conflict

* fix: refactors and adjust flash llama lora logic

* fix: skip llama test due to CI issue (temp)

* fix: skip llama test CI (temp) 2

* fix: revert skips and prefer updated ci token for tests

* fix: refactors and helpful comments

* fix: add noop in TensorParallelAdapterRowLinear too

* fix: refactor and move shard_lora_weights logic

* fix: exit early if no adapter_data

---------

Co-authored-by: Derek <datavistics@gmail.com>
This commit is contained in:
drbh 2024-06-25 14:46:27 -04:00 committed by GitHub
parent a2a97b05d6
commit 04e1af94d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
79 changed files with 2785 additions and 76 deletions

View File

@ -145,6 +145,13 @@ COPY server/marlin/ .
# Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
# Build Lorax Punica kernels
FROM kernel-builder as lorax-punica-builder
WORKDIR /usr/src
COPY server/Makefile-lorax-punica Makefile
# Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
# Build Transformers CUDA kernels
FROM kernel-builder as custom-kernels-builder
WORKDIR /usr/src
@ -215,6 +222,7 @@ COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from marlin kernels builder
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
@ -266,4 +274,4 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]
# CMD ["--json-output"]

View File

@ -157,6 +157,7 @@ async fn prefill(
top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![],
slots: vec![],
adapter_id: None,
})
.collect();

View File

@ -60,6 +60,9 @@
- local: conceptual/speculation
title: Speculation (Medusa, ngram)
- local: conceptual/guidance
title: How Guidance Works (via outlines)
title: How Guidance Works (via outlines
- local: conceptual/lora
title: LoRA (Low-Rank Adaptation)
title: Conceptual Guides

View File

@ -416,6 +416,14 @@ Options:
[env: MAX_CLIENT_BATCH_SIZE=]
[default: 4]
```
## LORA_ADAPTERS
```shell
--lora-adapters <LORA_ADAPTERS>
Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during startup that will be available to callers via the `adapter_id` field in a request
[env: LORA_ADAPTERS=]
```
## HELP
```shell

View File

@ -0,0 +1,65 @@
# LoRA (Low-Rank Adaptation)
## What is LoRA?
LoRA is a technique that allows for efficent fine-tuning a model while only updating a small portion of the model's weights. This is useful when you have a large model that has been pre-trained on a large dataset, but you want to fine-tune it on a smaller dataset or for a specific task.
LoRA works by adding a small number of additional weights to the model, which are used to adapt the model to the new dataset or task. These additional weights are learned during the fine-tuning process, while the rest of the model's weights are kept fixed.
## How is it used?
LoRA can be used in many ways and the community is always finding new ways to use it. Here are some examples of how you can use LoRA:
Technically, LoRA can be used to fine-tune a large language model on a small dataset. However, these use cases can span a wide range of applications, such as:
- fine-tuning a language model on a small dataset
- fine-tuning a language model on a domain-specific dataset
- fine-tuning a language model on a dataset with limited labels
## Optimizing Inference with LoRA
LoRA's can be used during inference by mutliplying the adapter weights with the model weights at each specified layer. This process can be computationally expensive, but due to awesome work by [punica-ai](https://github.com/punica-ai/punica) and the [lorax](https://github.com/predibase/lorax) team, optimized kernels/and frameworks have been developed to make this process more efficient. TGI leverages these optimizations in order to provide fast and efficient inference with mulitple LoRA models.
## Serving multiple LoRA adapters with TGI
Once a LoRA model has been trained, it can be used to generate text or perform other tasks just like a regular language model. However, because the model has been fine-tuned on a specific dataset, it may perform better on that dataset than a model that has not been fine-tuned.
In practice its often useful to have multiple LoRA models, each fine-tuned on a different dataset or for a different task. This allows you to use the model that is best suited for a particular task or dataset.
Text Generation Inference (TGI) now supports loading multiple LoRA models at startup that can be used in generation requests. This feature is available starting from version `~2.0.6` and is compatible with LoRA models trained using the `peft` library.
### Specifying LoRA models
To use LoRA in TGI, when starting the server, you can specify the list of LoRA models to load using the `LORA_ADAPTERS` environment variable. For example:
```bash
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
```
In the server logs, you will see the following message:
```txt
Loading adapter weights into model: predibase/customer_support
Loading adapter weights into model: predibase/dbpedia
```
## Generate text
You can then use these models in generation requests by specifying the `lora_model` parameter in the request payload. For example:
```json
curl 127.0.0.1:3000/generate \
-X POST \
-H 'Content-Type: application/json' \
-d '{
"inputs": "Hello who are you?",
"parameters": {
"max_new_tokens": 40,
"adapter_id": "predibase/customer_support"
}
}'
```
> **Note:** The Lora feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally documentation and an improved client library will be published soon.
An updated tutorial with detailed examples will be published soon. Stay tuned!

View File

@ -452,6 +452,11 @@ struct Args {
/// Control the maximum number of inputs that a client can send in a single request
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
/// Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during
/// startup that will be available to callers via the `adapter_id` field in a request.
#[clap(long, env)]
lora_adapters: Option<String>,
}
#[derive(Debug)]
@ -485,6 +490,7 @@ fn shard_manager(
max_total_tokens: usize,
max_batch_size: Option<usize>,
max_input_tokens: usize,
lora_adapters: Option<String>,
otlp_endpoint: Option<String>,
otlp_service_name: String,
log_level: LevelFilter,
@ -620,6 +626,11 @@ fn shard_manager(
envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
}
// Lora Adapters
if let Some(lora_adapters) = lora_adapters {
envs.push(("LORA_ADAPTERS".into(), lora_adapters.into()));
}
// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@ -1060,6 +1071,7 @@ fn spawn_shards(
let rope_scaling = args.rope_scaling;
let rope_factor = args.rope_factor;
let max_batch_size = args.max_batch_size;
let lora_adapters = args.lora_adapters.clone();
thread::spawn(move || {
shard_manager(
model_id,
@ -1085,6 +1097,7 @@ fn spawn_shards(
max_total_tokens,
max_batch_size,
max_input_tokens,
lora_adapters,
otlp_endpoint,
otlp_service_name,
max_log_level,

View File

@ -134,6 +134,8 @@ message Request {
repeated uint32 blocks = 9;
/// Paged attention slots
repeated uint32 slots = 10;
/// LORA adapter index
optional string adapter_id = 11;
}
message Batch {

View File

@ -177,6 +177,7 @@ impl Client {
}),
prefill_logprobs: true,
top_n_tokens: 20,
adapter_id: None,
});
n_tokens += max_input_length;

View File

@ -244,6 +244,7 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
adapter_id: None,
};
let batch = Batch {
id: u64::MAX,

View File

@ -429,6 +429,7 @@ mod tests {
stop_sequences: vec![],
},
top_n_tokens: 0,
adapter_id: None,
},
response_tx,
span: info_span!("entry"),

View File

@ -351,6 +351,7 @@ impl State {
top_n_tokens: entry.request.top_n_tokens,
blocks,
slots,
adapter_id: entry.request.adapter_id.clone(),
});
// Set batch_time
entry.batch_time = Some(Instant::now());
@ -491,6 +492,7 @@ mod tests {
stop_sequences: vec![],
},
top_n_tokens: 0,
adapter_id: None,
},
response_tx,
span: info_span!("entry"),

View File

@ -302,6 +302,11 @@ pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub grammar: Option<GrammarType>,
/// Lora adapter id
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub adapter_id: Option<String>,
}
fn default_max_new_tokens() -> Option<u32> {
@ -328,6 +333,7 @@ fn default_parameters() -> GenerateParameters {
seed: None,
top_n_tokens: None,
grammar: None,
adapter_id: None,
}
}

View File

@ -673,6 +673,7 @@ async fn completions(
seed,
top_n_tokens: None,
grammar: None,
..Default::default()
},
})
.collect();
@ -1115,6 +1116,7 @@ async fn chat_completions(
seed,
top_n_tokens: req.top_logprobs,
grammar,
..Default::default()
},
};

View File

@ -202,6 +202,7 @@ impl Validation {
decoder_input_details,
top_n_tokens,
grammar,
adapter_id,
..
} = request.parameters;
@ -383,6 +384,7 @@ impl Validation {
parameters,
stopping_parameters,
top_n_tokens,
adapter_id,
})
}
@ -678,6 +680,7 @@ pub(crate) struct ValidGenerateRequest {
pub parameters: ValidParameters,
pub stopping_parameters: ValidStoppingParameters,
pub top_n_tokens: u32,
pub adapter_id: Option<String>,
}
#[derive(Error, Debug)]

View File

@ -4,6 +4,7 @@ include Makefile-vllm
include Makefile-awq
include Makefile-eetq
include Makefile-selective-scan
include Makefile-lorax-punica
unit-tests:
pytest -s -vv -m "not private" tests

View File

@ -0,0 +1,12 @@
lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc
build-lorax-punica:
if [ ! -d 'lorax-punica' ]; then \
git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \
fi
cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit)
cd lorax-punica && git submodule update --init --recursive
cd lorax-punica/server/punica_kernels && python setup.py build
install-lorax-punica: build-lorax-punica
cd lorax-punica/server/punica_kernels && python setup.py install

View File

@ -17,7 +17,12 @@ def get_test_model():
tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
model = TestModel(
torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
"test_model_id",
torch.nn.Linear(1, 1),
tokenizer,
False,
torch.float32,
torch.device("cpu"),
)
return model

View File

@ -0,0 +1,13 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/__init__.py
# License: Apache License Version 2.0, January 2004
from text_generation_server.adapters.weights import (
AdapterBatchData,
AdapterBatchMetadata,
)
__all__ = [
"AdapterBatchData",
"AdapterBatchMetadata",
]

View File

@ -0,0 +1,44 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/config.py
# License: Apache License Version 2.0, January 2004
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
import torch
from text_generation_server.adapters.weights import AdapterWeights
if TYPE_CHECKING:
from text_generation_server.models.model import Model
@dataclass
class ModuleMap:
module_name: str
module_weights: Dict[str, Tuple[torch.Tensor, str]]
@dataclass
class AdapterConfig(ABC):
base_model_name_or_path: str
@abstractmethod
def map_weights_for_model(
self,
adapter_weights: Dict[int, AdapterWeights],
weight_names: Tuple[str],
) -> Tuple[ModuleMap, Set[str]]:
pass
@abstractmethod
def load_batched_adapter_weights(
self,
model: "Model",
module_map: ModuleMap,
layer_type: str,
unused_weight_names: Set[str],
dynamic: bool,
) -> Optional[AdapterWeights]:
pass

View File

@ -0,0 +1,482 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/lora.py
# License: Apache License Version 2.0, January 2004
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
import torch
from peft import LoraConfig as _LoraConfig
from torch.distributed import ProcessGroup
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
from text_generation_server.adapters.weights import (
AdapterBatchMetadata,
AdapterWeights,
BatchAdapterWeights,
)
from text_generation_server.utils.sgmv import (
BGMV_MAX_RANK,
MAX_RANK_CUSTOM,
get_tmp_tensors,
orient_for_rank,
pad_rank,
use_cutlass_shrink,
)
if TYPE_CHECKING:
from text_generation_server.models.model import Model
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
block_size = size // world_size
start = offset + rank * block_size
stop = offset + (rank + 1) * block_size
return start, stop
def shard_on_dim(
t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
):
world_size = process_group.size()
rank = process_group.rank()
size = t.shape[dim]
start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
if dim == 0:
tensor = t[start:stop]
elif dim == 1:
tensor = t[:, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
return tensor
def shard_lora_weights(
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
split_dim: int,
process_group: ProcessGroup,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
# [hidden_size, r]
weights_a = [
shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
]
# [r, hidden_size]
weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
return weights_a, weights_b
@dataclass
class LoraConfig(AdapterConfig):
r: int
target_modules: Optional[Union[List[str], str]]
fan_in_fan_out: bool
lora_alpha: int
use_rslora: bool
def map_weights_for_model(
self,
adapter_weights: Dict[int, AdapterWeights],
weight_names: Tuple[str],
) -> Tuple[ModuleMap, Set[str]]:
adapter_weight_names = set()
module_map = {}
for weight_name in weight_names:
lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
continue
module_map[weight_name] = {
"lora_A": (adapter_weights[lora_a_name], lora_a_name),
"lora_B": (adapter_weights[lora_b_name], lora_b_name),
}
adapter_weight_names.add(lora_a_name)
adapter_weight_names.add(lora_b_name)
return module_map, adapter_weight_names
def load_batched_adapter_weights(
self,
model: "Model",
module_map: Dict[str, Dict],
layer_type: str,
unused_weight_names: Set[str],
dynamic: bool,
) -> Optional[AdapterWeights]:
return LoraWeights.load(
self,
model,
module_map,
layer_type,
unused_weight_names,
)
@classmethod
def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
return cls(
base_model_name_or_path=hf_config.base_model_name_or_path,
r=hf_config.r,
target_modules=hf_config.target_modules,
fan_in_fan_out=hf_config.fan_in_fan_out,
lora_alpha=hf_config.lora_alpha,
use_rslora=(
hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False
),
)
class LoraWeights(AdapterWeights):
"""LoRA weights for a single adapter merged across all layers."""
def __init__(
self,
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
adapter_config: LoraConfig,
):
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
self._is_transposed = False
# [num_layers, hidden_size, r]
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
self._weights_a = torch.stack(weights_a)
# [num_layers, r, hidden_size]
self._weights_b = torch.stack(weights_b)
self.adapter_config = adapter_config
@property
def weights_a(self) -> torch.Tensor:
if self._is_transposed:
self._transpose_weights()
return self._weights_a
@property
def weights_b(self) -> torch.Tensor:
if self._is_transposed:
self._transpose_weights()
return self._weights_b
@property
def weights_a_t(self) -> torch.Tensor:
if not self._is_transposed:
self._transpose_weights()
return self._weights_a
@property
def weights_b_t(self) -> torch.Tensor:
if not self._is_transposed:
self._transpose_weights()
return self._weights_b
def _transpose_weights(self):
if self._use_cutlass_shrink:
# If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation
self._weights_a = self._weights_a.transpose(1, 2).contiguous()
self._weights_b = self._weights_b.transpose(1, 2).contiguous()
self._is_transposed = not self._is_transposed
@classmethod
def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
return [BatchLoraWeights]
@classmethod
def load(
cls,
config: LoraConfig,
model: "Model",
module_map: Dict[str, Dict],
layer_type: str,
unused_weight_names: Set[str],
) -> Optional[AdapterWeights]:
nlayers = model.get_num_layers_for_type(layer_type)
lora_a_list = [None] * nlayers
lora_b_list = [None] * nlayers
for layer_id in range(nlayers):
key = (layer_id, layer_type)
weight_name, layer = model.target_to_layer[key]
base_weight = layer.base_layer.linear.weight
base_device = base_weight.device
if weight_name not in module_map:
# There is no LoRA weight for this layer type in the adapter
return None
lora_a, lora_a_name = module_map[weight_name]["lora_A"]
lora_a = lora_a.to(base_device, model.dtype)
lora_b, lora_b_name = module_map[weight_name]["lora_B"]
lora_b = lora_b.to(base_device, model.dtype)
scale = get_scaling_factor(
config.lora_alpha,
config.r,
uses_rslora=config.use_rslora,
)
unused_weight_names.discard(lora_a_name)
unused_weight_names.discard(lora_b_name)
# Merge scaling factor into lora_b due to associativity of matrix multiplication:
# (A * B) * C = A * (B * C)
lora_a_list[layer_id] = lora_a.transpose(0, 1)
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
# pad lora ranks to be compatible with sgmv
lora_a_list = [
pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list
]
lora_b_list = [
pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list
]
if lora_a_list:
# update rank if it was padded
padded_rank = lora_a_list[0].size(1)
config.r = padded_rank
return LoraWeights(
*shard_lora_weights(
weights_a=lora_a_list,
weights_b=lora_b_list,
split_dim=0 if model.is_row_parallel(layer_type) else 1,
process_group=model.process_group,
),
config,
)
@dataclass
class RankSegments:
rank: int
lora_a_ptr: torch.Tensor
lora_b_ptr: torch.Tensor
# prefill (sgmv)
tmp_shrink: torch.Tensor
tmp_expand: torch.Tensor
segment_starts: torch.Tensor
segment_ends: torch.Tensor
# decode (bgmv)
indices: torch.Tensor
@dataclass
class BatchLoraWeights(BatchAdapterWeights):
lora_a: Dict[int, torch.Tensor]
lora_b: Dict[int, torch.Tensor]
adapter_index_configs: Dict[int, LoraConfig]
rank_data: Dict[int, RankSegments]
use_sgmv: bool
def has_adapter(self, adapter_index: int) -> bool:
return adapter_index in self.adapter_index_configs
def can_vectorize(self, pg: ProcessGroup) -> bool:
return all(
rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
for rank_data in self.rank_data.values()
)
@classmethod
def key(cls) -> str:
return "lora"
@classmethod
def load(
self,
adapter_weights: Dict[int, AdapterWeights],
meta: AdapterBatchMetadata,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> Optional["BatchLoraWeights"]:
adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
adapter_weights = {
k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
}
if not adapter_weights:
return None
first_weights = next(iter(adapter_weights.values()))
device = first_weights.weights_a.device
segment_indices = meta.segment_indices
lora_a = {
idx: adapter_weights[idx].weights_a
for idx in segment_indices
if idx in adapter_weights
}
lora_b = {
idx: adapter_weights[idx].weights_b
for idx in segment_indices
if idx in adapter_weights
}
max_rank = max(
(
adapter_weights[idx].lora_a_r
for idx in segment_indices
if idx in adapter_weights
),
default=0,
)
if prefill or max_rank > BGMV_MAX_RANK:
use_sgmv = True
lora_a_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_a.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_b.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
else:
use_sgmv = False
lora_a_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_a_t.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_b_t.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
adapter_index_configs = {
idx: adapter_weights[idx].adapter_config
for idx in segment_indices
if idx in adapter_weights
}
adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
rank_indices = defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
if adapter_idx not in adapter_weights:
continue
rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
if prefill_head_indices is not None:
j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
for head_index in prefill_head_indices:
# j cannot go out of bounds as that would mean there are tokens without corresponding adapters
if head_index < meta.adapter_segments[j]:
prefill_head_segment_ends[-1] += 1
else:
prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
j += 1
rank_data = {}
for rank, indices in rank_indices.items():
tmp_shrink = None
tmp_expand = None
segment_starts = None
segment_ends = None
batch_indices = None
if use_sgmv:
lora_a_ptr_indices = lora_a_ptr[indices]
tmp_shrink, tmp_expand = get_tmp_tensors(
lora_a_ptr_indices.size(0), rank, device
)
segment_starts = meta.adapter_segments[indices]
segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
if prefill_head_indices is not None:
for i, segment_index in enumerate(indices):
segment_starts[i] = prefill_head_segment_starts[segment_index]
segment_ends[i] = prefill_head_segment_ends[segment_index]
else:
rank_indices = set(indices)
batch_indices = [
adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
]
batch_indices = [
idx if idx in rank_indices else -1 for idx in batch_indices
]
batch_indices = torch.tensor(
batch_indices, dtype=torch.int64, device=device
)
rank_data[rank] = RankSegments(
rank=rank,
tmp_shrink=tmp_shrink,
tmp_expand=tmp_expand,
lora_a_ptr=lora_a_ptr[indices],
lora_b_ptr=lora_b_ptr[indices],
segment_starts=segment_starts,
segment_ends=segment_ends,
indices=batch_indices,
)
return BatchLoraWeights(
lora_a=lora_a,
lora_b=lora_b,
adapter_index_configs=adapter_index_configs,
rank_data=rank_data,
use_sgmv=use_sgmv,
)
def get_scaling_factor(
lora_alpha: int,
r: int,
uses_rslora: bool = False,
) -> float:
"""Computes the scaling factor for the lora weights."""
if uses_rslora:
return lora_alpha / (r**0.5)
return lora_alpha / r
def _convert_lora(v: AdapterWeights) -> AdapterWeights:
if hasattr(v, "lora_weights"):
return v.lora_weights
return v

View File

@ -0,0 +1,158 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/weights.py
# License: Apache License Version 2.0, January 2004
from abc import ABC, abstractclassmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Type
import torch
@dataclass
class AdapterBatchMetadata:
# [batch_size]
adapter_indices: torch.Tensor
# [num_adapters]
adapter_set: Set[int]
# [num_segments + 1]
adapter_segments: torch.Tensor
# [num_segments]
# maps from segment index to adapter index, i.e.:
# segment_indices[s] == adapter_indices[i]
segment_indices: List[int]
class AdapterWeights(ABC):
@abstractclassmethod
def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
pass
@property
def speculative_tokens(self) -> int:
return 0
class BatchAdapterWeights(ABC):
@abstractclassmethod
def has_adapter(self, adapter_index: int) -> bool:
pass
@abstractclassmethod
def key(cls) -> str:
pass
@abstractclassmethod
def load(
cls,
adapter_weights: Dict[int, AdapterWeights],
meta: "AdapterBatchMetadata",
prefill: bool,
prefill_head_indices: torch.Tensor,
) -> Optional["BatchAdapterWeights"]:
pass
class LayerAdapterWeights:
"""Adapter weights that apply to a particular layer."""
def __init__(self):
self.adapter_weights: Dict[int, AdapterWeights] = {}
def add_adapter(self, adapter_idx: int, weights: AdapterWeights):
self.adapter_weights[adapter_idx] = weights
def remove_adapter(self, adapter_idx: int):
if adapter_idx not in self.adapter_weights:
return
del self.adapter_weights[adapter_idx]
@property
def max_speculative_tokens(self) -> int:
return max(
adapter_weights.speculative_tokens
for adapter_weights in self.adapter_weights.values()
)
def is_empty(self) -> bool:
return len(self.adapter_weights) == 0
def get_data(
self,
meta: AdapterBatchMetadata,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> Dict[str, BatchAdapterWeights]:
# bucket adapters by batch class
adapter_batch_types: Dict[
Type[BatchAdapterWeights], Dict[int, AdapterWeights]
] = defaultdict(dict)
for adapter_index, adapter_weights in self.adapter_weights.items():
for batch_type in adapter_weights.get_batch_types():
adapter_batch_types[batch_type][adapter_index] = adapter_weights
batch_data = {}
for batch_type, adapter_weights in adapter_batch_types.items():
batched_weights = batch_type.load(
adapter_weights, meta, prefill, prefill_head_indices
)
if batched_weights is not None:
batch_data[batch_type.key()] = batched_weights
return batch_data
@dataclass
class AdapterBatchData:
meta: AdapterBatchMetadata
# layer type -> adapter type -> batch weight data
data: Dict[str, Dict[str, BatchAdapterWeights]]
prefill: bool
@staticmethod
def from_meta(
meta: AdapterBatchMetadata,
weights: Dict[str, LayerAdapterWeights],
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> "AdapterBatchData":
data = {}
for k, v in weights.items():
if v.is_empty():
continue
data[k] = v.get_data(
meta, prefill, prefill_head_indices if k == "lm_head" else None
)
return AdapterBatchData(meta=meta, data=data, prefill=prefill)
def ranks(self) -> Set[int]:
# TODO(travis): refactor to be less coupled to lora implementation
ranks = set()
for layer_data in self.data.values():
lora_data = layer_data.get("lora")
if lora_data is None:
continue
for rank_data in lora_data.rank_data.values():
ranks.add(rank_data.rank)
return ranks
def layer_names(self) -> Set[str]:
return set(self.data.keys())
def adapter_keys(self) -> Set[str]:
adapter_keys = set()
for layer_data in self.data.values():
adapter_keys.update(layer_data.keys())
return adapter_keys
@property
def max_rank(self) -> int:
ranks = self.ranks()
return max(ranks) if len(ranks) > 0 else 0

View File

@ -79,6 +79,18 @@ def serve(
if otlp_endpoint is not None:
setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)
lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
# split on comma and strip whitespace
lora_adapter_ids = (
[x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
)
if len(lora_adapter_ids) > 0:
logger.warning(
f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
)
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
dtype = None if dtype is None else dtype.value
@ -93,6 +105,7 @@ def serve(
)
server.serve(
model_id,
lora_adapter_ids,
revision,
sharded,
quantize,
@ -113,6 +126,7 @@ def download_weights(
logger_level: str = "INFO",
json_output: bool = False,
trust_remote_code: bool = False,
merge_lora: bool = False,
):
# Remove default handler
logger.remove()
@ -143,18 +157,28 @@ def download_weights(
) is not None
if not is_local_model:
try:
adapter_config_filename = hf_hub_download(
model_id, revision=revision, filename="adapter_config.json"
)
utils.download_and_unload_peft(
model_id, revision, trust_remote_code=trust_remote_code
)
is_local_model = True
utils.weight_files(model_id, revision, extension)
return
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
# TODO: maybe reverse the default value of merge_lora?
# currently by default we don't merge the weights with the base model
if merge_lora:
try:
adapter_config_filename = hf_hub_download(
model_id, revision=revision, filename="adapter_config.json"
)
utils.download_and_unload_peft(
model_id, revision, trust_remote_code=trust_remote_code
)
is_local_model = True
utils.weight_files(model_id, revision, extension)
return
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
else:
try:
utils.peft.download_peft(
model_id, revision, trust_remote_code=trust_remote_code
)
except Exception:
pass
try:
import json

View File

@ -12,3 +12,9 @@ from text_generation_server.layers.speculative import SpeculativeHead
# Just to add the `load` methods.
from text_generation_server.layers.layernorm import load_layer_norm
from text_generation_server.layers.conv import load_conv2d
from text_generation_server.layers.lora import (
LoraLinear,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)

View File

@ -0,0 +1,286 @@
import math
import os
from typing import TYPE_CHECKING, Optional, Tuple, List
import torch
import torch.distributed
from accelerate import init_empty_weights
from torch import nn
from torch.nn import functional as F
from torch.distributed import ProcessGroup
from text_generation_server.utils.sgmv import (
add_lora_a_bgmv,
add_lora_b_bgmv,
has_sgmv,
lora_a_sgmv_cutlass,
lora_b_sgmv_cutlass,
orient_for_rank,
)
if TYPE_CHECKING:
from text_generation_server.adapters import AdapterBatchData
from text_generation_server.adapters.lora import BatchLoraWeights
class LoraLinear(nn.Module):
def __init__(
self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
):
super().__init__()
self.base_layer = base_layer
self.layer_id = layer_id
self.process_group = process_group
def forward_layer_type(
self,
result: torch.Tensor,
input: torch.Tensor,
adapter_data: "AdapterBatchData",
layer_type: str,
start_idx: int,
end_idx: int,
) -> torch.Tensor:
if adapter_data is None:
return result
data = adapter_data.data.get(layer_type)
data: Optional["BatchLoraWeights"] = (
data.get("lora") if data is not None else None
)
if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
# In tensor-parallel configurations, each GPU processes a specific segment of the output.
# The 'result' tensor represents the full output, which can vary in size based on
# the layer type (e.g., attention vs. feed-forward layers). We define the current
# segment using start_idx and end_idx. If the segment size doesn't match this GPU's
# slice of 'result', we create a zero tensor of the correct size for LoRA computation.
# This approach ensures accurate LoRA application across various layer sizes and
# configurations, adapting to different model architectures and parallelization strategies.
#
# Example scenarios where this is necessary:
# 1. The adapter's size doesn't evenly divide across GPUs.
# 2. We're processing the last segment which might be smaller.
# 3. Different projection layers (q, k, v) have different sizes.
if end_idx - start_idx != result.shape[1]:
proj = torch.zeros_like(result[:, start_idx:end_idx])
else:
proj = result
for r, rank_segments in data.rank_data.items():
lora_a_ptr = rank_segments.lora_a_ptr
lora_b_ptr = rank_segments.lora_b_ptr
if lora_a_ptr is None or lora_b_ptr is None:
raise ValueError("LoRA data is missing")
if data.use_sgmv:
# Use SGMV for prefill
v = lora_a_sgmv_cutlass(
input,
rank_segments.tmp_shrink,
lora_a_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
r,
)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
lora_b_sgmv_cutlass(
proj,
v,
rank_segments.tmp_expand,
lora_b_ptr,
rank_segments.segment_starts,
rank_segments.segment_ends,
self.layer_id,
)
else:
# Use BGMV for decode
v = torch.zeros(
(input.size(0), r), dtype=input.dtype, device=input.device
)
# TODO: error with [-1, 0], but not [0, -1]
add_lora_a_bgmv(
v,
input,
lora_a_ptr,
rank_segments.indices,
self.layer_id,
)
if self.process_group.size() > 1:
v = self.collect_lora_a(v)
add_lora_b_bgmv(
proj,
v,
lora_b_ptr,
rank_segments.indices,
self.layer_id,
)
if end_idx - start_idx != result.shape[1]:
result[:, start_idx:end_idx] += proj
else:
for adapter_index in adapter_data.meta.adapter_set:
if data is not None and data.has_adapter(adapter_index):
adapter_mask = (
(adapter_data.meta.adapter_indices == adapter_index)
.to(input.dtype)
.view(-1, 1)
)
layer_result = self.forward_lora(
input, data, adapter_index, adapter_mask
)
result[:, start_idx:end_idx] += layer_result
return result
def forward_lora(
self,
input: torch.Tensor,
data: "BatchLoraWeights",
adapter_index: int,
adapter_mask: torch.Tensor,
) -> torch.Tensor:
lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
lora_a = orient_for_rank(lora_a, lora_b.size(0))
a_out = input @ lora_a
if self.process_group.size() > 1:
a_out = self.collect_lora_a(a_out)
result = (a_out @ lora_b) * adapter_mask
return result
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
raise NotImplementedError("Implemented in subclasses")
class TensorParallelMultiAdapterLinear(LoraLinear):
def __init__(
self,
base_layer: nn.Module,
layer_id: int,
layer_names: List[str],
sizes: List[int],
process_group: ProcessGroup,
):
super().__init__(base_layer, layer_id, process_group)
self.layer_names = layer_names
self.sizes = sizes
@classmethod
def load(
cls,
base_layer: nn.Module,
layer_id: int,
layer_names: List[str],
sizes: List[int],
process_group: ProcessGroup,
):
return TensorParallelMultiAdapterLinear(
base_layer, layer_id, layer_names, sizes, process_group
)
def forward(
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
) -> torch.Tensor:
result = self.base_layer(input)
# noop if no layer names are provided (e.g. for models without adapters)
if self.layer_names is None:
return result
# handle models like Bloom that have inputs of shape
# (batch_size, sequence_length, hidden_size)
# we need to reshape them to (batch_size * sequence_length, hidden_size)
# for the LoRA computation, then reshape back
prev_shape = result.shape
is_3d = len(input.shape) >= 3
if is_3d:
input = input.reshape(-1, input.shape[-1])
result = result.reshape(-1, result.shape[-1])
offset = 0
for i, layer_name in enumerate(self.layer_names):
start_idx = offset // self.process_group.size()
# The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
# projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
# ensures correct slicing of the result tensor, accommodating variations like grouped-query
# attention where k_proj and v_proj differ from q_proj. This allows precise application of
# LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
# different projection sizes across layers and model architectures.
if self.sizes is not None:
offset += self.sizes[i]
end_idx = offset // self.process_group.size()
else:
end_idx = result.shape[1]
result = self.forward_layer_type(
result, input, adapter_data, layer_name, start_idx, end_idx
)
if is_3d:
result = result.reshape(prev_shape)
return result
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
# Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
# We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
#
# TODO(travis): this is not very efficient as we do an all-gather for every adapter,
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
gathered_tensors = [
torch.empty_like(a_out) for _ in range(self.process_group.size())
]
torch.distributed.all_gather(gathered_tensors, a_out)
return torch.cat(gathered_tensors, dim=1)
class TensorParallelAdapterRowLinear(LoraLinear):
def __init__(self, base_layer, layer_id, layer_name, process_group):
super().__init__(base_layer, layer_id, process_group)
self.layer_name = layer_name
@classmethod
def load(cls, base_layer, layer_id, layer_name, process_group):
return cls(base_layer, layer_id, layer_name, process_group)
def forward(
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
) -> torch.Tensor:
result = self.base_layer(input)
if self.layer_name is None:
return result
# Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
stride = result.shape[-1] // self.process_group.size()
start_idx = self.process_group.rank() * stride
end_idx = (self.process_group.rank() + 1) * stride
self.forward_layer_type(
result, input, adapter_data, self.layer_name, start_idx, end_idx
)
return result
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
# Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
# We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
#
# TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
torch.distributed.all_reduce(a_out, group=self.process_group)
return a_out

View File

@ -6,7 +6,7 @@ from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional
from typing import Optional, List
from pathlib import Path
from text_generation_server.utils.speculate import get_speculate, set_speculate
@ -253,6 +253,7 @@ for data in ModelType:
def get_model(
model_id: str,
lora_adapter_ids: Optional[List[str]],
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
@ -595,6 +596,7 @@ def get_model(
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))

View File

@ -90,6 +90,7 @@ class BLOOMSharded(CausalLM):
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,

View File

@ -538,6 +538,7 @@ class CausalLM(Model):
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,

View File

@ -514,6 +514,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
input_ids,

View File

@ -724,6 +724,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
input_ids,

View File

@ -460,6 +460,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
input_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(

View File

@ -445,6 +445,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
token_embeds = self.embed_tokens(input_ids)
position_embeds = self.embed_positions(position_ids)

View File

@ -38,6 +38,8 @@ from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
@ -51,43 +53,61 @@ if SYSTEM == "rocm":
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
def load_attention(config, prefix, weights):
def load_attention(config, prefix, weights, layer_id):
# Only defined in granite.
bias = getattr(config, "attention_bias", False)
head_size = config.hidden_size // config.num_attention_heads
sizes = None
prefixes = None
# if specific model type, load the correct attention
if config.model_type == "phi3":
return TensorParallelColumnLinear.load_qkv(
prefix = f"{prefix}.qkv_proj"
base_layer = TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.qkv_proj",
prefix=prefix,
weights=weights,
bias=bias,
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)
elif config.model_type == "baichuan":
return TensorParallelColumnLinear.load_qkv(
prefix = f"{prefix}.W_pack"
base_layer = TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.W_pack",
prefix=prefix,
weights=weights,
bias=bias,
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)
else:
prefixes = ["q_proj", "k_proj", "v_proj"]
sizes = [
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
]
base_layer = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=bias,
)
# otherwise, load the default attention based on the number of heads
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=bias,
return TensorParallelMultiAdapterLinear.load(
base_layer=base_layer,
layer_id=layer_id,
layer_names=prefixes,
sizes=sizes,
process_group=weights.process_group,
)
class FlashLlamaAttention(torch.nn.Module):
def __init__(
self,
index: int,
prefix: str,
config,
weights,
@ -121,14 +141,23 @@ class FlashLlamaAttention(torch.nn.Module):
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
self.query_key_value = load_attention(config, prefix, weights, index)
self.index = index
self.o_proj = TensorParallelRowLinear.load(
o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.o_proj = TensorParallelAdapterRowLinear.load(
o_proj,
index,
"o_proj",
process_group=weights.process_group,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@ -145,8 +174,9 @@ class FlashLlamaAttention(torch.nn.Module):
slots,
input_lengths,
max_s,
adapter_data,
):
qkv = self.query_key_value(hidden_states)
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
@ -190,11 +220,13 @@ class FlashLlamaAttention(torch.nn.Module):
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
return self.o_proj(
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
)
class LlamaMLP(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, index):
super().__init__()
self.hidden_act = config.hidden_act
self.act = (
@ -209,29 +241,54 @@ class LlamaMLP(nn.Module):
),
)
)
prefixes = None
sizes = None
# Fuse gate and up proj
bias = getattr(config, "mlp_bias", False)
if config.model_type == "phi3":
self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
gate_up_proj = TensorParallelColumnLinear.load_gate_up(
config,
prefix=f"{prefix}.gate_up_proj",
weights=weights,
bias=bias,
)
else:
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
prefixes = [f"gate_proj", f"up_proj"]
sizes = [
config.intermediate_size,
config.intermediate_size,
]
gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=bias,
)
self.down_proj = TensorParallelRowLinear.load(
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
index,
layer_names=prefixes,
sizes=sizes,
process_group=weights.process_group,
)
down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=bias,
)
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
index,
"down_proj",
process_group=weights.process_group,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
@ -239,7 +296,7 @@ class LlamaMLP(nn.Module):
# TODO: This is a hotfix to be removed & properly refactored.
self.quantize = config.quantize
def forward(self, hidden_states):
def forward(self, hidden_states, adapter_data):
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
@ -253,20 +310,27 @@ class LlamaMLP(nn.Module):
device="cuda",
)
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
return self.down_proj(out)
return self.down_proj(out, adapter_data)
else:
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
class FlashLlamaLayer(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, index, prefix, config, weights):
super().__init__()
self.self_attn = FlashLlamaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
index=index,
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
)
self.mlp = LlamaMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@ -289,6 +353,7 @@ class FlashLlamaLayer(nn.Module):
slots,
input_lengths,
max_s,
adapter_data,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -303,6 +368,7 @@ class FlashLlamaLayer(nn.Module):
slots,
input_lengths,
max_s,
adapter_data,
)
# faster post attention rms norm
@ -310,7 +376,7 @@ class FlashLlamaLayer(nn.Module):
attn_output, res
)
mlp_output = self.mlp(normed_attn_res_output)
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
return mlp_output, attn_res
@ -325,6 +391,7 @@ class FlashLlamaModel(torch.nn.Module):
self.layers = nn.ModuleList(
[
FlashLlamaLayer(
index=layer_id,
prefix=(
f"model.layers.{layer_id}"
if not prefix
@ -360,6 +427,7 @@ class FlashLlamaModel(torch.nn.Module):
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data,
) -> torch.Tensor:
hidden_states = inputs_embeds
@ -382,6 +450,7 @@ class FlashLlamaModel(torch.nn.Module):
slots,
input_lengths,
max_s,
adapter_data,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -423,6 +492,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
@ -436,6 +506,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -38,6 +38,8 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
@ -107,12 +109,7 @@ class MistralConfig(PretrainedConfig):
class MistralAttention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
):
def __init__(self, prefix: str, config, weights, layer_id):
super().__init__()
self.max_past = (
config.sliding_window if config.sliding_window is not None else -1
@ -140,7 +137,7 @@ class MistralAttention(torch.nn.Module):
config.num_key_value_heads // weights.process_group.size()
)
self.query_key_value = TensorParallelColumnLinear.load_multi(
query_key_value = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
@ -148,12 +145,31 @@ class MistralAttention(torch.nn.Module):
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
head_size = config.hidden_size // config.num_attention_heads
self.query_key_value = TensorParallelMultiAdapterLinear.load(
query_key_value,
layer_id,
["q_proj", "k_proj", "v_proj"],
sizes=[
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
],
process_group=weights.process_group,
)
o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.o_proj = TensorParallelAdapterRowLinear.load(
o_proj,
layer_id,
"o_proj",
process_group=weights.process_group,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@ -171,8 +187,9 @@ class MistralAttention(torch.nn.Module):
input_lengths,
max_s,
prefill_cache_indices,
adapter_data,
):
qkv = self.query_key_value(hidden_states)
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
@ -224,11 +241,13 @@ class MistralAttention(torch.nn.Module):
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
return self.o_proj(
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
)
class MistralMLP(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, layer_id):
super().__init__()
self.hidden_act = config.hidden_act
self.act = (
@ -244,19 +263,37 @@ class MistralMLP(nn.Module):
)
)
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
layer_id,
["gate_proj", "up_proj"],
sizes=[
config.intermediate_size,
config.intermediate_size,
],
process_group=weights.process_group,
)
down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
layer_id,
"down_proj",
process_group=weights.process_group,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
@ -264,7 +301,7 @@ class MistralMLP(nn.Module):
# TODO: This is a hotfix to be removed & properly refactored.
self.quantize = config.quantize
def forward(self, hidden_states):
def forward(self, hidden_states, adapter_data):
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
@ -278,20 +315,27 @@ class MistralMLP(nn.Module):
device="cuda",
)
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
return self.down_proj(out)
return self.down_proj(out, adapter_data)
else:
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
class MistralLayer(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, layer_id):
super().__init__()
self.self_attn = MistralAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
layer_id=layer_id,
)
self.mlp = MistralMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
)
self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@ -315,6 +359,7 @@ class MistralLayer(nn.Module):
input_lengths,
max_s,
prefill_cache_indices,
adapter_data,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -330,6 +375,7 @@ class MistralLayer(nn.Module):
input_lengths,
max_s,
prefill_cache_indices,
adapter_data,
)
# faster post attention rms norm
@ -337,7 +383,7 @@ class MistralLayer(nn.Module):
attn_output, res
)
mlp_output = self.mlp(normed_attn_res_output)
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
return mlp_output, attn_res
@ -355,6 +401,7 @@ class MistralModel(torch.nn.Module):
prefix=f"{prefix}.layers.{layer_id}",
config=config,
weights=weights,
layer_id=layer_id,
)
for layer_id in range(config.num_hidden_layers)
]
@ -381,6 +428,7 @@ class MistralModel(torch.nn.Module):
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data: Optional[torch.Tensor] = None,
):
hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
@ -403,6 +451,7 @@ class MistralModel(torch.nn.Module):
input_lengths,
max_s,
prefill_cache_indices,
adapter_data,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -454,6 +503,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
@ -476,6 +526,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
max_s,
true_max_s,
prefill_cache_indices,
adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -638,6 +638,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:

View File

@ -390,6 +390,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.gpt_neox(
input_ids,

View File

@ -74,6 +74,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
# Unused here
pixel_attention_mask: Optional[torch.BoolTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids)
# TODO This is odd but apparently pali gemma position ids start at 1.

View File

@ -400,6 +400,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,

View File

@ -359,6 +359,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:

View File

@ -672,6 +672,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,

View File

@ -483,6 +483,7 @@ class FlashSantacoderForCausalLM(nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,

View File

@ -525,6 +525,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:

View File

@ -741,6 +741,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:

View File

@ -178,6 +178,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
):
inputs_embeds = self.language_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:

View File

@ -13,6 +13,7 @@ from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Iterable, Optional, Tuple, List, Type, Dict
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
@ -31,6 +32,7 @@ from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
import text_generation_server.models.globals as tgi_globals
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
from text_generation_server.utils.import_utils import (
empty_cache,
@ -114,6 +116,9 @@ class FlashCausalLMBatch(Batch):
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Adapter metadata for each request
adapter_meta: AdapterBatchMetadata
# Number of blocks in this batch
num_blocks: int
# Maximum number of blocks
@ -174,6 +179,9 @@ class FlashCausalLMBatch(Batch):
stopping_criterias = []
top_n_tokens = []
adapter_indices_list = []
adapter_set = set()
# Cumulative length
cumulative_length = 0
cumulative_max_length = 0
@ -225,6 +233,10 @@ class FlashCausalLMBatch(Batch):
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(r.adapter_id, 0)
adapter_indices_list.append(torch.full((input_length,), adapter_index))
adapter_set.add(adapter_index)
# Paged attention
# Remove one as the first token des not have a past
speculative_length = get_speculate()
@ -296,6 +308,10 @@ class FlashCausalLMBatch(Batch):
max_length, input_length + max_new_tokens + speculative_length
)
adapter_indices = torch.cat(adapter_indices_list).to(
dtype=torch.int64, device=device
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device, tokenizer
)
@ -339,6 +355,11 @@ class FlashCausalLMBatch(Batch):
input_lengths, dtype=torch.int32, device=device
)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
@ -393,6 +414,12 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor=top_n_tokens_tensor,
num_blocks=num_blocks,
max_blocks=max_blocks,
adapter_meta=AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
speculative_ids=None,
)
@ -443,6 +470,7 @@ class FlashCausalLMBatch(Batch):
stopping_criterias = []
top_n_tokens = []
adapter_set = set()
num_blocks = 0
max_blocks = 0
@ -471,6 +499,11 @@ class FlashCausalLMBatch(Batch):
top_n_tokens.append(self.top_n_tokens[idx])
adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(
self.requests[idx].adapter_id, 0
)
adapter_set.add(adapter_index)
remaining_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
@ -498,6 +531,7 @@ class FlashCausalLMBatch(Batch):
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
adapter_indices = self.adapter_meta.adapter_indices[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices]
block_tables_tensor = self.block_tables_tensor[indices]
input_lengths_tensor = self.input_lengths_tensor[indices]
@ -513,6 +547,11 @@ class FlashCausalLMBatch(Batch):
# Move to GPU now that we have the whole tensor
slot_indices = slot_indices.to(device)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
return type(self)(
batch_id=self.batch_id,
requests=requests,
@ -543,6 +582,12 @@ class FlashCausalLMBatch(Batch):
num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
adapter_meta=AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
)
@classmethod
@ -596,6 +641,14 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
total_indices_size = sum(
b.adapter_meta.adapter_indices.shape[0] for b in batches
)
adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
total_indices_size
)
adapter_set = set()
adapter_segment_builder = SegmentConcatBuilder()
start_slots = []
block_tables = []
@ -613,6 +666,7 @@ class FlashCausalLMBatch(Batch):
# Cumulative length
cumulative_batch_size = 0
cumulative_slots = 0
cumulative_adapter_indices_size = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
@ -637,6 +691,21 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
slots[slots_start_index:slots_end_index] = batch.slots
# Copy over adapter indices
adapter_start_index = cumulative_adapter_indices_size
adapter_end_index = (
cumulative_adapter_indices_size
+ batch.adapter_meta.adapter_indices.shape[0]
)
adapter_indices[adapter_start_index:adapter_end_index] = (
batch.adapter_meta.adapter_indices
)
cumulative_adapter_indices_size = adapter_end_index
adapter_set.update(batch.adapter_meta.adapter_set)
adapter_segment_builder.concat(
batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
)
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length]
@ -680,6 +749,8 @@ class FlashCausalLMBatch(Batch):
else None
)
adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
return cls(
batch_id=batches[0].batch_id,
requests=requests,
@ -710,6 +781,12 @@ class FlashCausalLMBatch(Batch):
num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
adapter_meta=AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
)
def __len__(self):
@ -719,6 +796,7 @@ class FlashCausalLMBatch(Batch):
class FlashCausalLM(Model):
def __init__(
self,
model_id: str,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
num_layers: int,
@ -738,6 +816,7 @@ class FlashCausalLM(Model):
self.kv_cache = []
super(FlashCausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=False,
@ -895,12 +974,13 @@ class FlashCausalLM(Model):
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
free_memory = get_free_memory(self.device, MEMORY_FRACTION)
batch_num_blocks = batch.num_blocks if batch is not None else 0
num_blocks = (
# Leave 5% for some wiggle room
int((free_memory * 0.95) // total_cache_size)
# Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+ batch.num_blocks
+ batch_num_blocks
)
del batch
@ -1001,7 +1081,7 @@ class FlashCausalLM(Model):
)
def forward(
self, batch: FlashCausalLMBatch
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
if batch.speculative_ids is not None:
@ -1080,6 +1160,7 @@ class FlashCausalLM(Model):
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
adapter_data=adapter_data,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
@ -1116,7 +1197,34 @@ class FlashCausalLM(Model):
prefill = batch.cu_seqlen_prefill is not None
prefill_logprobs = batch.prefill_next_token_indices is not None
out, speculative_logits = self.forward(batch)
# Update adapter indices for speculative tokens (if present)
adapter_meta = batch.adapter_meta
if batch.speculative_ids is not None:
B, speculative_length = batch.speculative_ids.shape
new_length = speculative_length + 1
adapter_indices = (
adapter_meta.adapter_indices.unsqueeze(-1)
.expand(B, new_length)
.reshape(-1)
)
adapter_segments = adapter_meta.adapter_segments * new_length
adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_meta.adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_meta.segment_indices,
)
# Assign pointers to adapter weights
# TODO(travis): don't update this if indices haven't changed
adapter_data = AdapterBatchData.from_meta(
adapter_meta,
self.layer_to_adapter_weights,
prefill,
batch.prefill_head_indices,
)
out, speculative_logits = self.forward(batch, adapter_data)
if prefill:
next_token_logits = (
@ -1128,8 +1236,13 @@ class FlashCausalLM(Model):
if prefill_logprobs
else speculative_logits
)
next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
len(batch)
)
else:
next_token_logits = out
next_adapter_indices = batch.adapter_meta.adapter_indices
speculate = get_speculate()
(
@ -1195,6 +1308,12 @@ class FlashCausalLM(Model):
# In decode, we do not need this as we can just increment position ids
next_position_ids[i] = batch.position_ids[end_index - 1]
# Initialize adapter indices
# In decode, we only have one token per row in the batch, so grab last index
next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
end_index - 1
]
# Used to gather prefill logprobs
# Copy batch.input_ids to prefill_token_indices
if prefill_logprobs:
@ -1220,6 +1339,16 @@ class FlashCausalLM(Model):
batch.position_ids = next_position_ids + accepted_ids
batch.input_lengths_tensor += accepted_ids
batch.slot_indices += accepted_ids
batch.adapter_meta.adapter_indices = next_adapter_indices
if prefill:
# adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
batch.adapter_meta.adapter_segments = torch.tensor(
adapter_segments,
dtype=torch.int32,
device=batch.adapter_meta.adapter_segments.device,
)
if prefill and prefill_logprobs:
# Get prefill logprobs

View File

@ -62,6 +62,7 @@ class FlashCohere(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashCohere, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),

View File

@ -87,6 +87,7 @@ class FlashDbrx(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashDbrx, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),

View File

@ -62,6 +62,7 @@ class FlashGemma(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashGemma, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),

View File

@ -69,6 +69,7 @@ class FlashGPT2(FlashCausalLM):
model = FlashGPT2ForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashGPT2, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),

View File

@ -1,9 +1,10 @@
import os
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from typing import Optional
from typing import Optional, Tuple, Dict, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
@ -13,12 +14,24 @@ from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
hub,
)
tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import SYSTEM
ADAPTER_LAYERS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class FlashLlama(FlashCausalLM):
def __init__(
@ -29,6 +42,7 @@ class FlashLlama(FlashCausalLM):
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
lora_adapter_ids: Optional[list] = [],
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
@ -78,6 +92,7 @@ class FlashLlama(FlashCausalLM):
model = FlashLlamaForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
@ -88,3 +103,69 @@ class FlashLlama(FlashCausalLM):
rank=rank,
world_size=world_size,
)
@property
def supports_adapter_loading(self) -> bool:
return True
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}
prefix = "model.layers"
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model.
if hasattr(self.model, "language_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
_model = self.model.text_model
else:
_model = self.model
for i, layer in enumerate(_model.model.layers):
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
return layer_weights
@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS
@property
def default_traced_adapter_layers(self) -> List[str]:
return ["q_proj", "v_proj"]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL

View File

@ -3,7 +3,7 @@ import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, Tuple
from typing import Optional, Tuple, Dict, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import set_sliding_window
@ -21,6 +21,18 @@ from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
ADAPTER_LAYERS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class BaseFlashMistral(FlashCausalLM):
def __init__(
self,
@ -83,6 +95,7 @@ class BaseFlashMistral(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
num_layers, num_kv_heads, head_size = self.get_layer_config(model)
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=num_layers,
@ -102,6 +115,75 @@ class BaseFlashMistral(FlashCausalLM):
model.model.head_size,
)
@property
def supports_adapter_loading(self) -> bool:
return True
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}
prefix = "model.layers"
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model.
if hasattr(self.model, "language_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
_model = self.model.text_model
else:
_model = self.model
for i, layer in enumerate(_model.model.layers):
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
# TODO: this is a hack to avoid the gate_proj for
# FlashStarcoder2 that doesnt have these layers
if hasattr(layer.mlp, "gate_up_proj"):
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
return layer_weights
@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS
@property
def default_traced_adapter_layers(self) -> List[str]:
return ["q_proj", "v_proj"]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL
class FlashMistral(BaseFlashMistral):
def __init__(

View File

@ -69,6 +69,7 @@ class FlashNeoXSharded(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashNeoXSharded, self).__init__(
model_id=model_id,
model=model.to(device),
tokenizer=tokenizer,
num_layers=len(model.gpt_neox.layers),

View File

@ -90,6 +90,7 @@ class FlashPhi(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashPhi, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),

View File

@ -71,6 +71,7 @@ class FlashQwen2(BaseFlashMistral):
torch.distributed.barrier(group=self.process_group)
super(BaseFlashMistral, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),

View File

@ -78,6 +78,7 @@ class FlashRWSharded(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashRWSharded, self).__init__(
model_id=model_id,
model=model.to(device),
tokenizer=tokenizer,
num_layers=len(model.transformer.h),

View File

@ -80,6 +80,7 @@ class FlashSantacoderSharded(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
super(FlashSantacoderSharded, self).__init__(
model_id=model_id,
model=model.to(device),
tokenizer=tokenizer,
num_layers=len(model.transformer.h),

View File

@ -70,6 +70,7 @@ class FlashStarcoder2(BaseFlashMistral):
torch.distributed.barrier(group=self.process_group)
super(BaseFlashMistral, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),

View File

@ -212,6 +212,7 @@ class GalacticaSharded(CausalLM):
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,

View File

@ -1,6 +1,7 @@
import torch
import os
from loguru import logger
from typing import Dict
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
@ -32,3 +33,14 @@ MODEL_ID = None
def set_model_id(model_id: str):
global MODEL_ID
MODEL_ID = model_id
# NOTE: eventually we should move this into the router and pass back the
# index in all cases.
global ADAPTER_TO_INDEX
ADAPTER_TO_INDEX: Dict[str, int] = None
def set_adapter_to_index(adapter_to_index: Dict[str, int]):
global ADAPTER_TO_INDEX
ADAPTER_TO_INDEX = adapter_to_index

View File

@ -65,6 +65,7 @@ class GPTNeoxSharded(CausalLM):
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,

View File

@ -83,6 +83,7 @@ class IDEFICSSharded(IdeficsCausalLM):
torch.distributed.barrier(group=self.process_group)
super(IdeficsCausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,

View File

@ -634,6 +634,7 @@ class IdeficsCausalLM(Model):
tokenizer.add_special_tokens({"pad_token": "<unk>"})
super(IdeficsCausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,

View File

@ -453,6 +453,7 @@ class Mamba(Model):
model = MambaModel(config, weights)
torch.distributed.barrier(group=self.process_group)
super(Mamba, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,

View File

@ -2,12 +2,24 @@ import inspect
import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
from collections import defaultdict
from transformers import PreTrainedTokenizerBase, PretrainedConfig
from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights
from text_generation_server.utils.adapter import (
load_and_merge_adapters,
AdapterParameters,
AdapterSource,
)
from loguru import logger
BASE_MODEL_ADAPTER_ID = "__base_model__"
B = TypeVar("B", bound=Batch)
@ -15,6 +27,7 @@ B = TypeVar("B", bound=Batch)
class Model(ABC):
def __init__(
self,
model_id: str,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
requires_padding: bool,
@ -24,7 +37,9 @@ class Model(ABC):
world_size: int = 1,
sliding_window: Optional[int] = None,
speculate: Optional[int] = None,
adapter_id: str = BASE_MODEL_ADAPTER_ID,
):
self.model_id = model_id
self.model = model.eval()
self.tokenizer = tokenizer
@ -42,6 +57,13 @@ class Model(ABC):
self.world_size = world_size
self.sliding_window = sliding_window if sliding_window != -1 else None
self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
LayerAdapterWeights
)
self.target_to_layer = self.adapter_target_to_layer()
self.loaded_adapters = set()
self.static_adapter_id = adapter_id
if speculate is None:
speculate = get_speculate()
self.speculate = speculate
@ -119,3 +141,136 @@ class Model(ABC):
raise RuntimeError(
f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
)
@property
def supports_adapter_loading(self) -> bool:
return False
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
return {}
@property
def adapter_layers(self) -> List[str]:
return []
@property
def default_traced_adapter_layers(self) -> List[str]:
return []
def get_num_layers_for_type(self, layer_type: str) -> int:
return 0
def is_row_parallel(self, layer_type: str) -> bool:
return False
@property
def max_speculative_tokens(self) -> int:
return max(
[
weights.max_speculative_tokens
for weights in self.layer_to_adapter_weights.values()
],
default=0,
)
def load_adapter(
self,
adapter_parameters: AdapterParameters,
adapter_source: AdapterSource,
adapter_index: int,
api_token: str,
dynamic: bool = True,
):
"""Loads adapter weights from disk / host memory on the GPU.
adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded
into model. Otherwise, the adapter weights are applied during the forward
pass and stored separately from the base model parameters.
"""
if adapter_index in self.loaded_adapters:
# Adapter already loaded
return
if not self.supports_adapter_loading:
raise ValueError("This model does not support adapter loading.")
if dynamic and not self.dynamic_adapter_loading_enabled:
raise ValueError(
f"This model was initialized with the adapter {self.static_adapter_id} "
f"and therefore does not support dynamic adapter loading. "
f"Please initialize a new model instance from the base model in "
f"order to use the dynamic adapter loading feature."
)
logger.info(
f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
)
weight_names = tuple([v[0] for v in self.target_to_layer.values()])
(
module_map,
adapter_config,
adapter_weight_names,
adapter_tokenizer,
) = load_and_merge_adapters(
self.model_id,
adapter_parameters,
adapter_source,
adapter_index,
weight_names,
api_token,
False,
)
unused_weight_names = adapter_weight_names.copy()
for layer_name in self.adapter_layers:
adapter_weights = adapter_config.load_batched_adapter_weights(
self,
module_map,
layer_name,
unused_weight_names,
dynamic,
)
if adapter_weights is None:
continue
layer_weights = self.layer_to_adapter_weights[layer_name]
layer_weights.add_adapter(adapter_index, adapter_weights)
if len(unused_weight_names) > 0:
logger.warning(
f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
)
if adapter_tokenizer is not None:
self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
self.loaded_adapters.add(adapter_index)
def offload_adapter(
self,
adapter_parameters: AdapterParameters,
adapter_source: AdapterSource,
adapter_index: int,
):
"""Offloads the adapter weights from GPU to CPU or disk."""
if adapter_index not in self.loaded_adapters:
# Adapter already offloaded
return
if not self.supports_adapter_loading:
raise ValueError("This model does not support adapter loading.")
if not self.dynamic_adapter_loading_enabled:
raise ValueError(
f"This model was initialized with the adapter {self.static_adapter_id} "
f"and therefore does not support dynamic adapter loading. "
f"Please initialize a new model instance from the base model in "
f"order to use the dynamic adapter loading feature."
)
for layer_name in self.adapter_layers:
if layer_name in self.layer_to_adapter_weights:
self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index)
self.loaded_adapters.remove(adapter_index)

View File

@ -90,6 +90,7 @@ class MPTSharded(CausalLM):
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=False,

View File

@ -63,6 +63,7 @@ class OPTSharded(CausalLM):
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,

View File

@ -60,6 +60,7 @@ class Phi(CausalLM):
model = PhiForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,

View File

@ -62,6 +62,7 @@ class RW(CausalLM):
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,

View File

@ -62,6 +62,7 @@ class SantaCoder(CausalLM):
)
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,

View File

@ -575,6 +575,7 @@ class Seq2SeqLM(Model):
tokenizer.bos_token_id = model.config.decoder_start_token_id
super(Seq2SeqLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,

View File

@ -73,6 +73,7 @@ class T5Sharded(Seq2SeqLM):
torch.distributed.barrier(group=self.process_group)
super(Seq2SeqLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,

View File

@ -222,7 +222,9 @@ class VlmCausalLM(BaseFlashMistral):
return VlmCausalLMBatch
def forward(
self, batch: VlmCausalLMBatch
self,
batch: VlmCausalLMBatch,
adapter_data: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
if batch.speculative_ids is not None:

View File

@ -29,7 +29,10 @@ except (ImportError, NotImplementedError):
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_model_id
from text_generation_server.models.globals import set_model_id, set_adapter_to_index
from text_generation_server.utils.adapter import (
AdapterParameters,
)
class SignalHandler:
@ -192,6 +195,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def serve(
model_id: str,
lora_adapter_ids: Optional[List[str]],
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
@ -203,6 +207,7 @@ def serve(
):
async def serve_inner(
model_id: str,
lora_adapter_ids: Optional[List[str]],
revision: Optional[str],
sharded: bool = False,
quantize: Optional[str] = None,
@ -211,6 +216,7 @@ def serve(
trust_remote_code: bool = False,
):
unix_socket_template = "unix://{}-{}"
adapter_to_index = {}
if sharded:
server_urls = [
unix_socket_template.format(uds_path, rank)
@ -224,6 +230,7 @@ def serve(
try:
model = get_model(
model_id,
lora_adapter_ids,
revision,
sharded,
quantize,
@ -232,10 +239,33 @@ def serve(
trust_remote_code,
max_input_tokens,
)
if len(lora_adapter_ids) > 0:
for index, adapter_id in enumerate(lora_adapter_ids):
# TODO: improve non merged adapter loading and long term
# improve adapter loading as a whole
adapter_parameters = AdapterParameters(
adapter_ids=[adapter_id],
weights=None, # will be set to 1
merge_strategy=0,
density=1.0,
majority_sign_method=0,
)
adapter_index = index + 1
adapter_to_index[adapter_id] = adapter_index
model.load_adapter(
adapter_parameters,
None, # adapter_source
adapter_index,
None, # api_token
False, # dynamic
)
except Exception:
logger.exception("Error when initializing model")
raise
set_adapter_to_index(adapter_to_index)
server = aio.server(
interceptors=[
ExceptionInterceptor(),
@ -266,6 +296,13 @@ def serve(
set_model_id(model_id)
asyncio.run(
serve_inner(
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
model_id,
lora_adapter_ids,
revision,
sharded,
quantize,
speculate,
dtype,
trust_remote_code,
)
)

View File

@ -0,0 +1,196 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/utils/adapter.py
# License: Apache License Version 2.0, January 2004
import warnings
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Set, Tuple
from safetensors.torch import load_file
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
from text_generation_server.pb import generate_pb2
from text_generation_server.utils.merges.strategies import merge_adapters
from text_generation_server.utils import hub
from text_generation_server.adapters.lora import LoraConfig
if TYPE_CHECKING:
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
BASE_MODEL_ADAPTER_ID = "__base_model__"
@dataclass
class AdapterParameters:
adapter_ids: Tuple[str]
weights: Tuple[float]
merge_strategy: NotImplemented
density: float
majority_sign_method: NotImplemented
@dataclass
class AdapterSource:
adapter_id: str
model_id: str
revision: str
def load_and_merge_adapters(
model_id: str,
adapter_parameters: AdapterParameters,
adapter_source: str,
adapter_index: int,
weight_names: Tuple[str],
api_token: str,
trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
if len(adapter_parameters.adapter_ids) == 1:
return load_module_map(
model_id,
adapter_parameters.adapter_ids[0],
adapter_source,
weight_names,
api_token,
trust_remote_code,
)
adapter_params = AdapterParametersContainer(
adapter_parameters, adapter_source, adapter_index
)
return _load_and_merge(
model_id, adapter_params, weight_names, api_token, trust_remote_code
)
@dataclass
class AdapterParametersContainer:
adapter_parameters: AdapterParameters
adapter_source: str
adapter_index: int
def __hash__(self) -> int:
return self.adapter_index
@lru_cache(maxsize=32)
def _load_and_merge(
model_id: str,
adapter_params: AdapterParametersContainer,
weight_names: Tuple[str],
api_token: str,
trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
params = adapter_params.adapter_parameters
adapters_to_merge = []
merged_weight_names = set()
tokenizer = None
for adapter_id in params.adapter_ids:
if adapter_id == BASE_MODEL_ADAPTER_ID:
raise ValueError("Base model adapter cannot be merged.")
module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
load_module_map(
model_id,
adapter_id,
adapter_params.adapter_source,
weight_names,
api_token,
trust_remote_code,
)
)
adapters_to_merge.append((module_map, adapter_config))
merged_weight_names = merged_weight_names.union(adapter_weight_names)
if tokenizer is None:
tokenizer = adapter_tokenizer
if len(adapters_to_merge) == 0:
raise ValueError("No adapters to merge.")
module_map, adapter_config = merge_adapters(adapters_to_merge, params)
return module_map, adapter_config, merged_weight_names, tokenizer
def check_architectures(
model_id: str,
adapter_id: str,
adapter_config: "AdapterConfig",
trust_remote_code: bool = False,
):
try:
if not adapter_config.base_model_name_or_path:
# Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None)
return
expected_config = AutoConfig.from_pretrained(
model_id, trust_remote_code=trust_remote_code
)
model_config = AutoConfig.from_pretrained(
adapter_config.base_model_name_or_path, trust_remote_code=trust_remote_code
)
except Exception as e:
warnings.warn(
f"Unable to check architecture compatibility for adapter '{adapter_id}' "
f"against model '{model_id}'. Assuming they are compatible. Error: {e}"
)
return
if model_config.architectures == expected_config.architectures:
warnings.warn(
f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. "
f"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead."
)
else:
# TODO(travis): revisit this when we support clasification heads which will not use CausalLM
raise ValueError(
f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. "
f"Use --model-id '{adapter_config.base_model_name_or_path}' instead."
)
@lru_cache(maxsize=128)
def load_module_map(
model_id: str,
adapter_id: str,
adapter_source: str,
weight_names: Tuple[str],
api_token: str,
trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
revision = "main"
adapter_config = LoraConfig.load(adapter_id, api_token)
if adapter_config.base_model_name_or_path != model_id:
check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
adapter_filenames = hub._cached_adapter_weight_files(
adapter_id, revision=revision, extension=".safetensors"
)
try:
adapter_tokenizer = AutoTokenizer.from_pretrained(
adapter_config.config_path,
token=api_token,
trust_remote_code=trust_remote_code,
)
except Exception:
# Adapter does not have a tokenizer, so fallback to base model tokenizer
adapter_tokenizer = None
# load adapter weights from all shards (should have relatively small memory footprint)
adapter_weights = {}
for filename in adapter_filenames:
adapter_weights.update(load_file(filename))
# map the model weights to the relevant adapter weights (LoRA A and B matrices)
module_map, adapter_weight_names = adapter_config.map_weights_for_model(
adapter_weights, weight_names
)
return module_map, adapter_config, adapter_weight_names, adapter_tokenizer

View File

@ -18,6 +18,17 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
def _cached_adapter_weight_files(
adapter_id: str, revision: Optional[str], extension: str
) -> List[str]:
"""Guess weight files from the cached revision snapshot directory"""
d = _get_cached_revision_directory(adapter_id, revision)
if not d:
return []
filenames = _adapter_weight_files_from_dir(d, extension)
return filenames
def _cached_weight_files(
model_id: str, revision: Optional[str], extension: str
) -> List[str]:
@ -60,6 +71,33 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
return filenames
def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]:
# os.walk: do not iterate, just scan for depth 1, not recursively
# see _weight_files_from_dir, that's also what is done there
root, _, files = next(os.walk(str(d)))
filenames = [
os.path.join(root, f)
for f in files
if f.endswith(extension)
and "arguments" not in f
and "args" not in f
and "training" not in f
]
return filenames
def _adapter_config_files_from_dir(d: Path) -> List[str]:
# os.walk: do not iterate, just scan for depth 1, not recursively
# see _weight_files_from_dir, that's also what is done there
root, _, files = next(os.walk(str(d)))
filenames = [
os.path.join(root, f)
for f in files
if f.endswith(".json") and "arguments" not in f and "args" not in f
]
return filenames
def _get_cached_revision_directory(
model_id: str, revision: Optional[str]
) -> Optional[Path]:

View File

@ -0,0 +1,223 @@
import copy
from abc import ABC
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union
import torch
class AdapterParameters:
def __init__(
self, adapter_ids, weights, merge_strategy, density, majority_sign_method
):
self.adapter_ids = adapter_ids
self.weights = weights
self.merge_strategy = merge_strategy
self.density = density
self.majority_sign_method = majority_sign_method
from text_generation_server.utils.merges.utils import (
calculate_majority_sign_mask,
disjoint_merge,
prune,
)
if TYPE_CHECKING:
from text_generation_server.adapters.lora import LoraConfig
from text_generation_server.utils.adapter import ModuleMap
def _apply_weights(
tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor
) -> torch.Tensor:
if isinstance(tensors, torch.Tensor):
t = tensors
else:
t = torch.stack(tensors, dim=0)
# element-wise weighting of each task tensor
# need to unsqueeze weights to match task tensor dimensions
# for multiplication to apply element-wise
while len(t.shape) > len(w.shape):
w = w.unsqueeze(-1)
return t * w
class MergeStrategy(ABC):
def merge(
self, task_tensors: List[torch.Tensor], weights: torch.Tensor
) -> torch.Tensor:
raise NotImplementedError()
class LinearMerge(MergeStrategy):
def __init__(self, **kwargs):
pass
def merge(
self, task_tensors: List[torch.Tensor], weights: torch.Tensor
) -> torch.Tensor:
weighted_task_tensors = _apply_weights(task_tensors, weights)
return weighted_task_tensors.sum(dim=0)
class TiesMerge(MergeStrategy):
def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
self.density = density
self.majority_sign_method = majority_sign_method
def merge(
self, task_tensors: List[torch.Tensor], weights: torch.Tensor
) -> torch.Tensor:
# sparsify
task_tensors = [
prune(tensor, self.density, method="magnitude") for tensor in task_tensors
]
task_tensors = torch.stack(task_tensors, dim=0)
# elect sign before applying weights
majority_sign_mask = calculate_majority_sign_mask(
task_tensors, method=self.majority_sign_method
)
weighted_task_tensors = _apply_weights(task_tensors, weights)
# disjoint merge
return disjoint_merge(weighted_task_tensors, majority_sign_mask)
class DareLinearMerge(MergeStrategy):
def __init__(self, density: float, **kwargs):
self.density = density
def merge(
self, task_tensors: List[torch.Tensor], weights: torch.Tensor
) -> torch.Tensor:
# sparsify
task_tensors = [
prune(tensor, self.density, method="random", rescale=True)
for tensor in task_tensors
]
weighted_task_tensors = _apply_weights(task_tensors, weights)
return weighted_task_tensors.sum(dim=0)
class DareTiesMerge(MergeStrategy):
def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
self.density = density
self.majority_sign_method = majority_sign_method
def merge(
self, task_tensors: List[torch.Tensor], weights: torch.Tensor
) -> torch.Tensor:
# sparsify
task_tensors = [
prune(tensor, self.density, method="random", rescale=True)
for tensor in task_tensors
]
task_tensors = torch.stack(task_tensors, dim=0)
# elect sign before applying weights
majority_sign_mask = calculate_majority_sign_mask(
task_tensors, method=self.majority_sign_method
)
weighted_task_tensors = _apply_weights(task_tensors, weights)
# disjoint merge
mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask)
return mixed_task_tensors
strategy_registry: Dict[str, Type[MergeStrategy]] = {
"linear": LinearMerge,
"ties": TiesMerge,
"dare_linear": DareLinearMerge,
"dare_ties": DareTiesMerge,
}
def merge_adapters(
adapters: List[Tuple["ModuleMap", "LoraConfig"]],
merge_params: AdapterParameters,
) -> Tuple["ModuleMap", "LoraConfig"]:
# strategy_name = MergeStrategyEnum.Name(merge_params.merge_strategy).lower()
strategy_name = "linear"
weights = merge_params.weights
if not weights:
weights = torch.ones(len(adapters))
else:
weights = torch.tensor(weights)
merge_config = {
"density": merge_params.density,
# "majority_sign_method": MajoritySignMethodEnum.Name(
# merge_params.majority_sign_method
# ).lower(),
"majority_sign_method": "total",
}
merge_strategy = strategy_registry[strategy_name](**merge_config)
module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict(
lambda: defaultdict(lambda: defaultdict(list))
)
lora_configs = []
weight_name_to_adapter_idx = defaultdict(list)
# input is list of (module_map, lora_config) tuples
# convert into dict[k][param_name] -> list of tensors
for idx, (module_map, lora_config) in enumerate(adapters):
for weight_name, data in module_map.items():
weight_name_to_adapter_idx[weight_name].append(idx)
for k, (param_data, param_name) in data.items():
module_maps[weight_name][k][param_name].append(param_data)
lora_configs.append(lora_config)
# validate lora configs are compatible
_validate_lora_configs(lora_configs)
# merge tensors for each module such that we have a single ModuleMap:
# dict[k] -> merged tensor
merged_module_map: "ModuleMap" = defaultdict(dict)
for weight_name, data in module_maps.items():
indices = weight_name_to_adapter_idx[weight_name]
param_weights = weights[indices]
for k, param_data in data.items():
for param_name, tensors in param_data.items():
merged_tensor = merge_strategy.merge(tensors, param_weights)
merged_module_map[weight_name][k] = (merged_tensor, param_name)
# merge lora configs
merged_lora_config = _merge_lora_configs(lora_configs)
return merged_module_map, merged_lora_config
def _validate_lora_configs(lora_configs: List["LoraConfig"]):
# check that all configs have the same rank
ranks = set(lora_config.r for lora_config in lora_configs)
if len(ranks) > 1:
raise ValueError(
f"unable to merge adapters, lora configs have different ranks: {ranks}"
)
if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs):
raise ValueError(
"unable to merge adapters, lora configs have no target modules"
)
def _merge_lora_configs(lora_configs: List["LoraConfig"]) -> "LoraConfig":
merged_lora_config = copy.copy(lora_configs[0])
# merge target modules as a union operation
merged_target_modules = sorted(
set(
module
for lora_config in lora_configs
for module in lora_config.target_modules
)
)
merged_lora_config.target_modules = merged_target_modules
return merged_lora_config

View File

@ -0,0 +1,108 @@
# coding=utf-8
# From: https://github.com/huggingface/peft/pull/1364
# Copyright 2024-present the HuggingFace Inc. team.
# Modifications by Predibase, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal
import torch
def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor:
"""
Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction
`density`.
Args:
tensor (`torch.Tensor`):The tensor to prune.
density (`float`):The fraction of values to preserve. Should be in [0,1].
"""
mask = torch.zeros_like(tensor).reshape(-1)
k = int(density * tensor.reshape(-1).shape[0])
top_k = torch.topk(tensor.abs().reshape(-1), k=k, largest=True)
mask[top_k[1]] = 1
return tensor * mask.reshape(tensor.shape)
def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:
"""
Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction
`density`.
Args:
tensor (`torch.Tensor`):The tensor to prune.
density (`float`):The fraction of values to preserve. Should be in [0,1].
rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor.
"""
mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density))
pruned_tensor = tensor * mask
if rescale:
torch.div(input=pruned_tensor, other=density)
return pruned_tensor
def prune(
tensor: torch.Tensor,
density: float,
method: Literal["magnitude", "random"],
rescale: bool = False,
) -> torch.Tensor:
"""
Prune the values of task tensors based on the `method`.
Args:
tensor (`torch.Tensor`):The tensor to prune.
density (`float`):The fraction of values to preserve. Should be in [0,1].
method (`str`):The method to use to prune. Should be one of ["magnitude", "random"].
rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor.
"""
if density >= 1:
return tensor
elif density < 0:
raise ValueError("Density should be >= 0, got {density}")
if method == "magnitude":
return magnitude_based_pruning(tensor, density)
elif method == "random":
return random_pruning(tensor, density, rescale=rescale)
else:
raise ValueError(f"Unknown method {method}")
def calculate_majority_sign_mask(
tensor: torch.Tensor, method: Literal["total", "frequency"] = "total"
):
"""
Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0.
Args:
tensor (`torch.Tensor`):The tensor to get the mask from.
method (`str`):The method to use to get the mask. Should be one of ["total", "frequency"].
"""
sign = tensor.sign()
if method == "total":
sign_magnitude = (sign * tensor.abs()).sum(dim=0)
elif method == "frequency":
sign_magnitude = sign.sum(dim=0)
else:
raise RuntimeError(f'Unimplemented mask method "{method}"')
majority_sign = torch.where(sign_magnitude >= 0, 1, -1)
return sign == majority_sign
def disjoint_merge(task_tensors, majority_sign_mask):
mixed_task_tensors = (task_tensors * majority_sign_mask).sum(dim=0)
num_params_preserved = majority_sign_mask.sum(dim=0)
return mixed_task_tensors / torch.clamp(num_params_preserved, min=1.0)

View File

@ -1,5 +1,5 @@
import os
import json
from typing import Union
from loguru import logger
import torch
@ -43,3 +43,26 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
model.save_pretrained(cache_dir, safe_serialization=True)
model.config.save_pretrained(cache_dir)
tokenizer.save_pretrained(cache_dir)
def download_peft(
model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool
):
torch_dtype = torch.float16
try:
_model = AutoPeftModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
)
except Exception:
_model = AutoPeftModelForSeq2SeqLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
)
logger.info("Peft model downloaded.")

View File

@ -0,0 +1,66 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/utils/segments.py
# License: Apache License Version 2.0, January 2004
from typing import List, Tuple, Union
import torch
def find_segments(
adapter_indices: Union[torch.Tensor, List[int]]
) -> Tuple[List[int], List[int]]:
segments = [0]
segment_indices = []
if isinstance(adapter_indices, torch.Tensor):
# Calling .item() repeatedly on CUDA tensor is very slow, so we move it to CPU first
adapter_indices = adapter_indices.cpu().tolist()
start_index = 0
for i in range(1, len(adapter_indices)):
if adapter_indices[i] != adapter_indices[i - 1]:
segments.append(i)
segment_indices.append(adapter_indices[i - 1])
start_index = i
# Handle the last segment
if start_index < len(adapter_indices):
segments.append(len(adapter_indices))
segment_indices.append(adapter_indices[-1])
return segments, segment_indices
class SegmentConcatBuilder:
def __init__(self):
self.adapter_segment_indices = []
self.adapter_segment_tensors = []
def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]):
# Update adapter segments
if self.adapter_segment_tensors:
# Because we have already processed at least one batch, remove the 0 start index
# from this batch denoting the beginning of the segment, then offset all segment
# positions by the value of the last segment in the previous batch to account for
# the concatenation.
adapter_segments = (
adapter_segments[1:] + self.adapter_segment_tensors[-1][-1]
)
if (
self.adapter_segment_indices
and self.adapter_segment_indices[-1] == segment_indices[0]
):
# If the last segment in the previous batch is the same as the first segment in this batch,
# then we merge them together into a single segment. In effect, this means removing it from
# the segment indices of this batch, and extending the segment span by removing the segment
# end index from the previous batch.
segment_indices = segment_indices[1:]
self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1]
self.adapter_segment_indices.extend(segment_indices)
self.adapter_segment_tensors.append(adapter_segments)
def build(self) -> Tuple[torch.Tensor, List[int]]:
return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices

View File

@ -0,0 +1,248 @@
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/utils/sgmv.py
# License: Apache License Version 2.0, January 2004
import os
import warnings
from functools import lru_cache
from typing import List, Tuple
import torch
import torch.nn.functional as F
try:
import punica_kernels as _kernels
HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
except ImportError:
warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.")
_kernels = None
HAS_SGMV = False
MIN_SGMV_RANK = 8
MIN_RANK_CUSTOM = 16
MAX_RANK_CUSTOM = 128
SGMV_BLOCK_SIZE = 16
BGMV_MAX_RANK = 64
def has_sgmv() -> bool:
return HAS_SGMV
def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
"""Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size."""
if not has_sgmv():
return t
# tensor parallelism will result in effective rank being divided by world_size,
# so we need to scale the min rank to offset that effect
min_rank = MIN_SGMV_RANK * world_size
# if we're at or below the min rank, pad up to the min rank
# otherwise, pad to the nearest multiple of the block size
current_rank = t.size(dim)
target_rank = (
min_rank
if current_rank <= min_rank
else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
)
if current_rank == target_rank:
return t
pad_size = target_rank - current_rank
# see complicatd pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
pad = [0, 0] * t.dim()
pad[(t.dim() - dim - 1) * 2 + 1] = pad_size
pad = tuple(pad)
return F.pad(t, pad, mode="constant", value=0.0)
def use_cutlass_shrink(lora_rank: int) -> bool:
return lora_rank < MIN_RANK_CUSTOM
def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
return t.transpose(0, 1)
return t
# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py
def add_lora_sgmv_cutlass(
y: torch.Tensor,
x: torch.Tensor,
wa_ptr: torch.Tensor,
wb_ptr: torch.Tensor,
s_start: torch.Tensor,
s_end: torch.Tensor,
layer_idx: int,
lora_rank: int,
):
"""
Semantics:
y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
Weight matrix shape: `[num_layers, R, H1]`.
wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
Weight matrix shape: `[num_layers, R, H2]`.
s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices.
s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices.
layer_idx: Layer index of the weight matrices.
"""
if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
# Custom SGMV shrink only supports rank 16, 32, 64, 128
_add_lora_sgmv_cutlass_legacy(
y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank
)
return
tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
_kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
_kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)
def _add_lora_sgmv_cutlass_legacy(
y: torch.Tensor,
x: torch.Tensor,
wa_ptr: torch.Tensor,
wb_ptr: torch.Tensor,
s_start: torch.IntTensor,
s_end: torch.IntTensor,
layer_idx: int,
lora_rank: int,
):
tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
_kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
_kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
@lru_cache(maxsize=1)
def get_tmp_tensor(device: torch.device) -> torch.Tensor:
return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)
@lru_cache(maxsize=32)
def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
tmp_size = _kernels.sgmv_cutlass_tmp_size(size)
return torch.empty((tmp_size,), dtype=torch.uint8, device=device)
def get_tmp_tensor_for_size_no_kernels(size: int, device: torch.device) -> torch.Tensor:
return torch.empty((size,), dtype=torch.uint8, device=device)
def get_tmp_expand_size(size: int) -> int:
return _kernels.sgmv_cutlass_tmp_size(size)
def get_tmp_tensors(
nsegments: int, lora_rank: int, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
if use_cutlass_shrink(lora_rank) and has_sgmv():
tmp = get_tmp_tensor_for_size(nsegments, device)
return tmp, tmp
else:
tmp_shrink = get_tmp_tensor(device)
tmp_expand = get_tmp_tensor_for_size_no_kernels(nsegments, device)
return tmp_shrink, tmp_expand
def lora_a_sgmv_cutlass(
x: torch.Tensor,
tmp: torch.Tensor,
wa_ptr: torch.Tensor,
s_start: torch.IntTensor,
s_end: torch.IntTensor,
layer_idx: int,
lora_rank: int,
) -> torch.Tensor:
v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
_kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
else:
_kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
return v
def lora_b_sgmv_cutlass(
y: torch.Tensor,
v: torch.Tensor,
tmp: torch.Tensor,
wb_ptr: torch.Tensor,
s_start: torch.IntTensor,
s_end: torch.IntTensor,
layer_idx: int,
):
_kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
v: Shape: `[B, R]`. Temporary vector.
x: Shape: `[B, H1]`. Input vectors.
wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices.
wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
"""
def add_lora_a_bgmv(
v: torch.Tensor,
x: torch.Tensor,
wa_T_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
):
_kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)
def add_lora_b_bgmv(
y: torch.Tensor,
v: torch.Tensor,
wb_T_all: torch.Tensor,
indicies: torch.LongTensor,
layer_idx: int,
):
_kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)
def segmented_matmul(
y: torch.Tensor,
x: torch.Tensor,
w: List[torch.Tensor],
b: List[torch.Tensor],
s_start: torch.IntTensor,
s_end: torch.IntTensor,
):
for i in range(len(w)):
if s_end[i] - s_start[i] <= 0:
continue
xi = x[s_start[i] : s_end[i]]
wi = w[i]
bi = b[i]
y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi)