Enable multiple LoRA adapters (#2010)
* feat: first draft load multiple lora
* feat: load weights within layer and refactor lora pass
* fix: refactor and reduce lora math
* feat: baseline impl single request multi lora support
* feat: prefer lorax implementation and port loading logic
* fix: prefer adapter_data and refactors
* feat: prefer lorax's custom punica kernels and add mlp loras
* fix: adjust batch for bgmv
* fix: adjust adapter_segments logic when in batch
* fix: refactor and move changes to v3 proto
* fix: pass model_id for all flash causal lms
* fix: pass model_id for all causal and seq2seq lms
* fix: add model_id to model test
* feat: add lora support to mistral and refactors
* feat: prefer model id in request
* fix: include rust code for adapter id
* feat: bump launcher and add new lora docs
* feat: support base model generation and refactors
* fix: rename doc to retry ci build
* feat: support if vlm models
* fix: add adapter_data param and avoid missing layers
* fix: add adapter_data param to phi and neox
* fix: update all models forwards to include adapter_data
* fix: add model_id to IdeficsCausalLM
* Update lora.md: fixed a typo
* Update lora.md: fixing spam image
* fix: add lora kernel to dockerfile, support running without kernels and refactors
* fix: avoid dockerfile conflict
* fix: refactors and adjust flash llama lora logic
* fix: skip llama test due to CI issue (temp)
* fix: skip llama test CI (temp) 2
* fix: revert skips and prefer updated ci token for tests
* fix: refactors and helpful comments
* fix: add noop in TensorParallelAdapterRowLinear too
* fix: refactor and move shard_lora_weights logic
* fix: exit early if no adapter_data

---------

Co-authored-by: Derek <datavistics@gmail.com>
This commit is contained in:
parent a2a97b05d6
commit 04e1af94d7

Dockerfile (10 changes)
@ -145,6 +145,13 @@ COPY server/marlin/ .
|
|||
# Build specific version of transformers
|
||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
|
||||
|
||||
# Build Lorax Punica kernels
|
||||
FROM kernel-builder as lorax-punica-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/Makefile-lorax-punica Makefile
|
||||
# Build specific version of punica kernels
|
||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
|
||||
|
||||
# Build Transformers CUDA kernels
|
||||
FROM kernel-builder as custom-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
|
@ -215,6 +222,7 @@ COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86
|
|||
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
# Copy build artifacts from marlin kernels builder
|
||||
COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
||||
# Copy builds artifacts from vllm builder
|
||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||
|
@ -266,4 +274,4 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
|
|||
RUN chmod +x /tgi-entrypoint.sh
|
||||
|
||||
ENTRYPOINT ["/tgi-entrypoint.sh"]
|
||||
CMD ["--json-output"]
|
||||
# CMD ["--json-output"]
|
||||
|
|
|
@ -157,6 +157,7 @@ async fn prefill(
|
|||
top_n_tokens: top_n_tokens.unwrap_or(0),
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
adapter_id: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
|
|
@ -60,6 +60,9 @@
|
|||
- local: conceptual/speculation
|
||||
title: Speculation (Medusa, ngram)
|
||||
- local: conceptual/guidance
|
||||
title: How Guidance Works (via outlines)
|
||||
title: How Guidance Works (via outlines
|
||||
- local: conceptual/lora
|
||||
title: LoRA (Low-Rank Adaptation)
|
||||
|
||||
|
||||
title: Conceptual Guides
|
||||
|
|
|
@ -416,6 +416,14 @@ Options:
|
|||
[env: MAX_CLIENT_BATCH_SIZE=]
|
||||
[default: 4]
|
||||
|
||||
```
|
||||
## LORA_ADAPTERS
|
||||
```shell
|
||||
--lora-adapters <LORA_ADAPTERS>
|
||||
Lora Adapters: a comma-separated list of adapter ids, e.g. `repo/adapter1,repo/adapter2`, to load during startup. These adapters will be available to callers via the `adapter_id` field in a request
|
||||
|
||||
[env: LORA_ADAPTERS=]
|
||||
|
||||
```
|
||||
## HELP
|
||||
```shell
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
# LoRA (Low-Rank Adaptation)
|
||||
|
||||
## What is LoRA?
|
||||
|
||||
LoRA is a technique that allows for efficient fine-tuning of a model while only updating a small portion of the model's weights. This is useful when you have a large model that has been pre-trained on a large dataset, but you want to fine-tune it on a smaller dataset or for a specific task.
|
||||
|
||||
LoRA works by adding a small number of additional weights to the model, which are used to adapt the model to the new dataset or task. These additional weights are learned during the fine-tuning process, while the rest of the model's weights are kept fixed.
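Concretely, in the standard LoRA formulation (shown here for reference; not specific to TGI), the pretrained weight matrix `W_0` stays frozen and only a low-rank update is learned:

```latex
h = W_0 x + \Delta W x = W_0 x + \frac{\alpha}{r} B A x,
\qquad A \in \mathbb{R}^{r \times k},\quad B \in \mathbb{R}^{d \times r},\quad r \ll \min(d, k)
```

Only `A` and `B` are trained and shipped in the adapter, which is why adapter checkpoints are small compared to the base model.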
|
||||
|
||||
## How is it used?
|
||||
|
||||
LoRA can be used in many ways, and the community is always finding new ways to use it. Here are some examples of how you can use LoRA:
|
||||
|
||||
Technically, LoRA fine-tunes a large language model on a small dataset, but in practice these use cases span a wide range of applications, such as:
|
||||
|
||||
- fine-tuning a language model on a small dataset
|
||||
- fine-tuning a language model on a domain-specific dataset
|
||||
- fine-tuning a language model on a dataset with limited labels
|
||||
|
||||
## Optimizing Inference with LoRA
|
||||
|
||||
LoRA adapters can be used during inference by multiplying the adapter weights with the model weights at each specified layer. This process can be computationally expensive, but thanks to the work of [punica-ai](https://github.com/punica-ai/punica) and the [lorax](https://github.com/predibase/lorax) team, optimized kernels and frameworks have been developed to make it more efficient. TGI leverages these optimizations to provide fast and efficient inference with multiple LoRA models.
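To make the cost concrete, here is a simplified sketch of what a single LoRA-adapted linear layer computes, in plain PyTorch and ignoring the batched punica/SGMV kernels that TGI actually uses (the function and tensor names below are illustrative, not TGI APIs):

```python
import torch

def lora_linear(
    x: torch.Tensor,          # (batch, in_features) activations
    weight: torch.Tensor,     # (out_features, in_features) frozen base weight
    lora_a: torch.Tensor,     # (r, in_features) adapter A matrix
    lora_b: torch.Tensor,     # (out_features, r) adapter B matrix
    lora_alpha: float,
) -> torch.Tensor:
    """y = x @ W^T + (alpha / r) * ((x @ A^T) @ B^T)"""
    rank = lora_a.size(0)
    scaling = lora_alpha / rank
    base = x @ weight.T
    # The extra work is two skinny matmuls; cheap because r is much smaller
    # than either in_features or out_features
    delta = (x @ lora_a.T) @ lora_b.T
    return base + scaling * delta
```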
|
||||
|
||||
## Serving multiple LoRA adapters with TGI
|
||||
|
||||
Once a LoRA model has been trained, it can be used to generate text or perform other tasks just like a regular language model. However, because the model has been fine-tuned on a specific dataset, it may perform better on that dataset than a model that has not been fine-tuned.
|
||||
|
||||
In practice, it's often useful to have multiple LoRA models, each fine-tuned on a different dataset or for a different task. This allows you to use the model that is best suited for a particular task or dataset.
|
||||
|
||||
Text Generation Inference (TGI) now supports loading multiple LoRA models at startup that can be used in generation requests. This feature is available starting from version `~2.0.6` and is compatible with LoRA models trained using the `peft` library.
|
||||
|
||||
### Specifying LoRA models
|
||||
|
||||
To use LoRA in TGI, specify the list of LoRA models to load when starting the server via the `LORA_ADAPTERS` environment variable. For example:
|
||||
|
||||
```bash
|
||||
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
|
||||
```
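If you start the server through `text-generation-launcher`, the equivalent CLI flag added in this PR can be used instead (the base model id below is only an example):

```bash
text-generation-launcher \
    --model-id mistralai/Mistral-7B-v0.1 \
    --lora-adapters predibase/customer_support,predibase/dbpedia
```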
|
||||
|
||||
In the server logs, you will see messages like the following:
|
||||
|
||||
```txt
|
||||
Loading adapter weights into model: predibase/customer_support
|
||||
Loading adapter weights into model: predibase/dbpedia
|
||||
```
|
||||
|
||||
## Generate text
|
||||
|
||||
You can then use these adapters in generation requests by specifying the `adapter_id` parameter in the request payload. For example:
|
||||
|
||||
```bash
|
||||
curl 127.0.0.1:3000/generate \
|
||||
-X POST \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"inputs": "Hello who are you?",
|
||||
"parameters": {
|
||||
"max_new_tokens": 40,
|
||||
"adapter_id": "predibase/customer_support"
|
||||
}
|
||||
}'
|
||||
```
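The same request can be sent from Python; here is a minimal sketch using the `requests` library against a locally running server:

```python
import requests

response = requests.post(
    "http://127.0.0.1:3000/generate",
    json={
        "inputs": "Hello who are you?",
        "parameters": {
            "max_new_tokens": 40,
            # Omit `adapter_id` to generate with the base model instead
            "adapter_id": "predibase/customer_support",
        },
    },
)
print(response.json()["generated_text"])
```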
|
||||
|
||||
> **Note:** The LoRA feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally, documentation and an improved client library will be published soon.
|
||||
|
||||
An updated tutorial with detailed examples will be published soon. Stay tuned!
|
|
@ -452,6 +452,11 @@ struct Args {
|
|||
/// Control the maximum number of inputs that a client can send in a single request
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_client_batch_size: usize,
|
||||
|
||||
/// Lora Adapters: a comma-separated list of adapter ids, e.g. `repo/adapter1,repo/adapter2`, to load during
|
||||
/// startup. These adapters will be available to callers via the `adapter_id` field in a request.
|
||||
#[clap(long, env)]
|
||||
lora_adapters: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -485,6 +490,7 @@ fn shard_manager(
|
|||
max_total_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
max_input_tokens: usize,
|
||||
lora_adapters: Option<String>,
|
||||
otlp_endpoint: Option<String>,
|
||||
otlp_service_name: String,
|
||||
log_level: LevelFilter,
|
||||
|
@ -620,6 +626,11 @@ fn shard_manager(
|
|||
envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
|
||||
}
|
||||
|
||||
// Lora Adapters
|
||||
if let Some(lora_adapters) = lora_adapters {
|
||||
envs.push(("LORA_ADAPTERS".into(), lora_adapters.into()));
|
||||
}
|
||||
|
||||
// If huggingface_hub_cache is some, pass it to the shard
|
||||
// Useful when running inside a docker container
|
||||
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
|
||||
|
@ -1060,6 +1071,7 @@ fn spawn_shards(
|
|||
let rope_scaling = args.rope_scaling;
|
||||
let rope_factor = args.rope_factor;
|
||||
let max_batch_size = args.max_batch_size;
|
||||
let lora_adapters = args.lora_adapters.clone();
|
||||
thread::spawn(move || {
|
||||
shard_manager(
|
||||
model_id,
|
||||
|
@ -1085,6 +1097,7 @@ fn spawn_shards(
|
|||
max_total_tokens,
|
||||
max_batch_size,
|
||||
max_input_tokens,
|
||||
lora_adapters,
|
||||
otlp_endpoint,
|
||||
otlp_service_name,
|
||||
max_log_level,
|
||||
|
|
|
@ -134,6 +134,8 @@ message Request {
|
|||
repeated uint32 blocks = 9;
|
||||
/// Paged attention slots
|
||||
repeated uint32 slots = 10;
|
||||
/// LORA adapter index
|
||||
optional string adapter_id = 11;
|
||||
}
|
||||
|
||||
message Batch {
|
||||
|
|
|
@ -177,6 +177,7 @@ impl Client {
|
|||
}),
|
||||
prefill_logprobs: true,
|
||||
top_n_tokens: 20,
|
||||
adapter_id: None,
|
||||
});
|
||||
n_tokens += max_input_length;
|
||||
|
||||
|
|
|
@ -244,6 +244,7 @@ impl Health for ShardedClient {
|
|||
// Block 0 is reserved for health checks
|
||||
blocks: vec![0],
|
||||
slots: (0..16).collect(),
|
||||
adapter_id: None,
|
||||
};
|
||||
let batch = Batch {
|
||||
id: u64::MAX,
|
||||
|
|
|
@ -429,6 +429,7 @@ mod tests {
|
|||
stop_sequences: vec![],
|
||||
},
|
||||
top_n_tokens: 0,
|
||||
adapter_id: None,
|
||||
},
|
||||
response_tx,
|
||||
span: info_span!("entry"),
|
||||
|
|
|
@ -351,6 +351,7 @@ impl State {
|
|||
top_n_tokens: entry.request.top_n_tokens,
|
||||
blocks,
|
||||
slots,
|
||||
adapter_id: entry.request.adapter_id.clone(),
|
||||
});
|
||||
// Set batch_time
|
||||
entry.batch_time = Some(Instant::now());
|
||||
|
@ -491,6 +492,7 @@ mod tests {
|
|||
stop_sequences: vec![],
|
||||
},
|
||||
top_n_tokens: 0,
|
||||
adapter_id: None,
|
||||
},
|
||||
response_tx,
|
||||
span: info_span!("entry"),
|
||||
|
|
|
@ -302,6 +302,11 @@ pub(crate) struct GenerateParameters {
|
|||
#[serde(default)]
|
||||
#[schema(nullable = true, default = "null", example = "null")]
|
||||
pub grammar: Option<GrammarType>,
|
||||
|
||||
/// Lora adapter id
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, default = "null", example = "null")]
|
||||
pub adapter_id: Option<String>,
|
||||
}
|
||||
|
||||
fn default_max_new_tokens() -> Option<u32> {
|
||||
|
@ -328,6 +333,7 @@ fn default_parameters() -> GenerateParameters {
|
|||
seed: None,
|
||||
top_n_tokens: None,
|
||||
grammar: None,
|
||||
adapter_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -673,6 +673,7 @@ async fn completions(
|
|||
seed,
|
||||
top_n_tokens: None,
|
||||
grammar: None,
|
||||
..Default::default()
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
|
@ -1115,6 +1116,7 @@ async fn chat_completions(
|
|||
seed,
|
||||
top_n_tokens: req.top_logprobs,
|
||||
grammar,
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
|
||||
|
|
|
@ -202,6 +202,7 @@ impl Validation {
|
|||
decoder_input_details,
|
||||
top_n_tokens,
|
||||
grammar,
|
||||
adapter_id,
|
||||
..
|
||||
} = request.parameters;
|
||||
|
||||
|
@ -383,6 +384,7 @@ impl Validation {
|
|||
parameters,
|
||||
stopping_parameters,
|
||||
top_n_tokens,
|
||||
adapter_id,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -678,6 +680,7 @@ pub(crate) struct ValidGenerateRequest {
|
|||
pub parameters: ValidParameters,
|
||||
pub stopping_parameters: ValidStoppingParameters,
|
||||
pub top_n_tokens: u32,
|
||||
pub adapter_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
|
|
|
@ -4,6 +4,7 @@ include Makefile-vllm
|
|||
include Makefile-awq
|
||||
include Makefile-eetq
|
||||
include Makefile-selective-scan
|
||||
include Makefile-lorax-punica
|
||||
|
||||
unit-tests:
|
||||
pytest -s -vv -m "not private" tests
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc
|
||||
|
||||
build-lorax-punica:
|
||||
if [ ! -d 'lorax-punica' ]; then \
|
||||
git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \
|
||||
fi
|
||||
cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit)
|
||||
cd lorax-punica && git submodule update --init --recursive
|
||||
cd lorax-punica/server/punica_kernels && python setup.py build
|
||||
|
||||
install-lorax-punica: build-lorax-punica
|
||||
cd lorax-punica/server/punica_kernels && python setup.py install
|
|
@ -17,7 +17,12 @@ def get_test_model():
|
|||
tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
|
||||
|
||||
model = TestModel(
|
||||
torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
|
||||
"test_model_id",
|
||||
torch.nn.Linear(1, 1),
|
||||
tokenizer,
|
||||
False,
|
||||
torch.float32,
|
||||
torch.device("cpu"),
|
||||
)
|
||||
return model
|
||||
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
# Origin: https://github.com/predibase/lorax
|
||||
# Path: lorax/server/lorax_server/adapters/__init__.py
|
||||
# License: Apache License Version 2.0, January 2004
|
||||
|
||||
from text_generation_server.adapters.weights import (
|
||||
AdapterBatchData,
|
||||
AdapterBatchMetadata,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AdapterBatchData",
|
||||
"AdapterBatchMetadata",
|
||||
]
|
|
@ -0,0 +1,44 @@
|
|||
# Origin: https://github.com/predibase/lorax
|
||||
# Path: lorax/server/lorax_server/adapters/config.py
|
||||
# License: Apache License Version 2.0, January 2004
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from text_generation_server.adapters.weights import AdapterWeights
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from text_generation_server.models.model import Model
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModuleMap:
|
||||
module_name: str
|
||||
module_weights: Dict[str, Tuple[torch.Tensor, str]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdapterConfig(ABC):
|
||||
base_model_name_or_path: str
|
||||
|
||||
@abstractmethod
|
||||
def map_weights_for_model(
|
||||
self,
|
||||
adapter_weights: Dict[int, AdapterWeights],
|
||||
weight_names: Tuple[str],
|
||||
) -> Tuple[ModuleMap, Set[str]]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_batched_adapter_weights(
|
||||
self,
|
||||
model: "Model",
|
||||
module_map: ModuleMap,
|
||||
layer_type: str,
|
||||
unused_weight_names: Set[str],
|
||||
dynamic: bool,
|
||||
) -> Optional[AdapterWeights]:
|
||||
pass
|
|
@ -0,0 +1,482 @@
|
|||
# Origin: https://github.com/predibase/lorax
|
||||
# Path: lorax/server/lorax_server/adapters/lora.py
|
||||
# License: Apache License Version 2.0, January 2004
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from peft import LoraConfig as _LoraConfig
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
|
||||
|
||||
from text_generation_server.adapters.weights import (
|
||||
AdapterBatchMetadata,
|
||||
AdapterWeights,
|
||||
BatchAdapterWeights,
|
||||
)
|
||||
from text_generation_server.utils.sgmv import (
|
||||
BGMV_MAX_RANK,
|
||||
MAX_RANK_CUSTOM,
|
||||
get_tmp_tensors,
|
||||
orient_for_rank,
|
||||
pad_rank,
|
||||
use_cutlass_shrink,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from text_generation_server.models.model import Model
|
||||
|
||||
|
||||
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
|
||||
block_size = size // world_size
|
||||
start = offset + rank * block_size
|
||||
stop = offset + (rank + 1) * block_size
|
||||
return start, stop
|
||||
|
||||
|
||||
def shard_on_dim(
|
||||
t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
|
||||
):
|
||||
world_size = process_group.size()
|
||||
rank = process_group.rank()
|
||||
|
||||
size = t.shape[dim]
|
||||
start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
|
||||
|
||||
if dim == 0:
|
||||
tensor = t[start:stop]
|
||||
elif dim == 1:
|
||||
tensor = t[:, start:stop]
|
||||
else:
|
||||
raise NotImplementedError("Let's make that generic when needed")
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def shard_lora_weights(
|
||||
weights_a: List[torch.Tensor],
|
||||
weights_b: List[torch.Tensor],
|
||||
split_dim: int,
|
||||
process_group: ProcessGroup,
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
# [hidden_size, r]
|
||||
weights_a = [
|
||||
shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
|
||||
]
|
||||
|
||||
# [r, hidden_size]
|
||||
weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
|
||||
|
||||
return weights_a, weights_b
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoraConfig(AdapterConfig):
|
||||
r: int
|
||||
target_modules: Optional[Union[List[str], str]]
|
||||
fan_in_fan_out: bool
|
||||
lora_alpha: int
|
||||
use_rslora: bool
|
||||
|
||||
def map_weights_for_model(
|
||||
self,
|
||||
adapter_weights: Dict[int, AdapterWeights],
|
||||
weight_names: Tuple[str],
|
||||
) -> Tuple[ModuleMap, Set[str]]:
|
||||
adapter_weight_names = set()
|
||||
module_map = {}
|
||||
for weight_name in weight_names:
|
||||
lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
|
||||
lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
|
||||
if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
|
||||
continue
|
||||
|
||||
module_map[weight_name] = {
|
||||
"lora_A": (adapter_weights[lora_a_name], lora_a_name),
|
||||
"lora_B": (adapter_weights[lora_b_name], lora_b_name),
|
||||
}
|
||||
adapter_weight_names.add(lora_a_name)
|
||||
adapter_weight_names.add(lora_b_name)
|
||||
return module_map, adapter_weight_names
|
||||
|
||||
def load_batched_adapter_weights(
|
||||
self,
|
||||
model: "Model",
|
||||
module_map: Dict[str, Dict],
|
||||
layer_type: str,
|
||||
unused_weight_names: Set[str],
|
||||
dynamic: bool,
|
||||
) -> Optional[AdapterWeights]:
|
||||
return LoraWeights.load(
|
||||
self,
|
||||
model,
|
||||
module_map,
|
||||
layer_type,
|
||||
unused_weight_names,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
|
||||
hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
|
||||
return cls(
|
||||
base_model_name_or_path=hf_config.base_model_name_or_path,
|
||||
r=hf_config.r,
|
||||
target_modules=hf_config.target_modules,
|
||||
fan_in_fan_out=hf_config.fan_in_fan_out,
|
||||
lora_alpha=hf_config.lora_alpha,
|
||||
use_rslora=(
|
||||
hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class LoraWeights(AdapterWeights):
|
||||
"""LoRA weights for a single adapter merged across all layers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weights_a: List[torch.Tensor],
|
||||
weights_b: List[torch.Tensor],
|
||||
adapter_config: LoraConfig,
|
||||
):
|
||||
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
|
||||
self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
|
||||
|
||||
self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
|
||||
self._is_transposed = False
|
||||
|
||||
# [num_layers, hidden_size, r]
|
||||
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
|
||||
self._weights_a = torch.stack(weights_a)
|
||||
|
||||
# [num_layers, r, hidden_size]
|
||||
self._weights_b = torch.stack(weights_b)
|
||||
|
||||
self.adapter_config = adapter_config
|
||||
|
||||
@property
|
||||
def weights_a(self) -> torch.Tensor:
|
||||
if self._is_transposed:
|
||||
self._transpose_weights()
|
||||
return self._weights_a
|
||||
|
||||
@property
|
||||
def weights_b(self) -> torch.Tensor:
|
||||
if self._is_transposed:
|
||||
self._transpose_weights()
|
||||
return self._weights_b
|
||||
|
||||
@property
|
||||
def weights_a_t(self) -> torch.Tensor:
|
||||
if not self._is_transposed:
|
||||
self._transpose_weights()
|
||||
return self._weights_a
|
||||
|
||||
@property
|
||||
def weights_b_t(self) -> torch.Tensor:
|
||||
if not self._is_transposed:
|
||||
self._transpose_weights()
|
||||
return self._weights_b
|
||||
|
||||
def _transpose_weights(self):
|
||||
if self._use_cutlass_shrink:
|
||||
# Only the cutlass-shrink path needs lora_a transposed; otherwise SGMV and BGMV already use the same orientation for lora_a
|
||||
self._weights_a = self._weights_a.transpose(1, 2).contiguous()
|
||||
self._weights_b = self._weights_b.transpose(1, 2).contiguous()
|
||||
self._is_transposed = not self._is_transposed
|
||||
|
||||
@classmethod
|
||||
def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
|
||||
return [BatchLoraWeights]
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
config: LoraConfig,
|
||||
model: "Model",
|
||||
module_map: Dict[str, Dict],
|
||||
layer_type: str,
|
||||
unused_weight_names: Set[str],
|
||||
) -> Optional[AdapterWeights]:
|
||||
nlayers = model.get_num_layers_for_type(layer_type)
|
||||
lora_a_list = [None] * nlayers
|
||||
lora_b_list = [None] * nlayers
|
||||
|
||||
for layer_id in range(nlayers):
|
||||
key = (layer_id, layer_type)
|
||||
weight_name, layer = model.target_to_layer[key]
|
||||
base_weight = layer.base_layer.linear.weight
|
||||
base_device = base_weight.device
|
||||
|
||||
if weight_name not in module_map:
|
||||
# There is no LoRA weight for this layer type in the adapter
|
||||
return None
|
||||
|
||||
lora_a, lora_a_name = module_map[weight_name]["lora_A"]
|
||||
lora_a = lora_a.to(base_device, model.dtype)
|
||||
|
||||
lora_b, lora_b_name = module_map[weight_name]["lora_B"]
|
||||
lora_b = lora_b.to(base_device, model.dtype)
|
||||
|
||||
scale = get_scaling_factor(
|
||||
config.lora_alpha,
|
||||
config.r,
|
||||
uses_rslora=config.use_rslora,
|
||||
)
|
||||
|
||||
unused_weight_names.discard(lora_a_name)
|
||||
unused_weight_names.discard(lora_b_name)
|
||||
|
||||
# Merge scaling factor into lora_b due to associativity of matrix multiplication:
|
||||
# (A * B) * C = A * (B * C)
|
||||
lora_a_list[layer_id] = lora_a.transpose(0, 1)
|
||||
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
|
||||
|
||||
# pad lora ranks to be compatible with sgmv
|
||||
lora_a_list = [
|
||||
pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list
|
||||
]
|
||||
lora_b_list = [
|
||||
pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list
|
||||
]
|
||||
|
||||
if lora_a_list:
|
||||
# update rank if it was padded
|
||||
padded_rank = lora_a_list[0].size(1)
|
||||
config.r = padded_rank
|
||||
|
||||
return LoraWeights(
|
||||
*shard_lora_weights(
|
||||
weights_a=lora_a_list,
|
||||
weights_b=lora_b_list,
|
||||
split_dim=0 if model.is_row_parallel(layer_type) else 1,
|
||||
process_group=model.process_group,
|
||||
),
|
||||
config,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RankSegments:
|
||||
rank: int
|
||||
|
||||
lora_a_ptr: torch.Tensor
|
||||
lora_b_ptr: torch.Tensor
|
||||
|
||||
# prefill (sgmv)
|
||||
tmp_shrink: torch.Tensor
|
||||
tmp_expand: torch.Tensor
|
||||
segment_starts: torch.Tensor
|
||||
segment_ends: torch.Tensor
|
||||
|
||||
# decode (bgmv)
|
||||
indices: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchLoraWeights(BatchAdapterWeights):
|
||||
lora_a: Dict[int, torch.Tensor]
|
||||
lora_b: Dict[int, torch.Tensor]
|
||||
adapter_index_configs: Dict[int, LoraConfig]
|
||||
rank_data: Dict[int, RankSegments]
|
||||
use_sgmv: bool
|
||||
|
||||
def has_adapter(self, adapter_index: int) -> bool:
|
||||
return adapter_index in self.adapter_index_configs
|
||||
|
||||
def can_vectorize(self, pg: ProcessGroup) -> bool:
|
||||
return all(
|
||||
rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
|
||||
for rank_data in self.rank_data.values()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def key(cls) -> str:
|
||||
return "lora"
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
self,
|
||||
adapter_weights: Dict[int, AdapterWeights],
|
||||
meta: AdapterBatchMetadata,
|
||||
prefill: bool,
|
||||
prefill_head_indices: Optional[torch.Tensor],
|
||||
) -> Optional["BatchLoraWeights"]:
|
||||
adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
|
||||
adapter_weights = {
|
||||
k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
|
||||
}
|
||||
if not adapter_weights:
|
||||
return None
|
||||
|
||||
first_weights = next(iter(adapter_weights.values()))
|
||||
device = first_weights.weights_a.device
|
||||
segment_indices = meta.segment_indices
|
||||
|
||||
lora_a = {
|
||||
idx: adapter_weights[idx].weights_a
|
||||
for idx in segment_indices
|
||||
if idx in adapter_weights
|
||||
}
|
||||
lora_b = {
|
||||
idx: adapter_weights[idx].weights_b
|
||||
for idx in segment_indices
|
||||
if idx in adapter_weights
|
||||
}
|
||||
|
||||
max_rank = max(
|
||||
(
|
||||
adapter_weights[idx].lora_a_r
|
||||
for idx in segment_indices
|
||||
if idx in adapter_weights
|
||||
),
|
||||
default=0,
|
||||
)
|
||||
|
||||
if prefill or max_rank > BGMV_MAX_RANK:
|
||||
use_sgmv = True
|
||||
lora_a_ptr = torch.tensor(
|
||||
[
|
||||
(
|
||||
adapter_weights[idx].weights_a.data_ptr()
|
||||
if idx in adapter_weights
|
||||
else 0
|
||||
)
|
||||
for idx in segment_indices
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
lora_b_ptr = torch.tensor(
|
||||
[
|
||||
(
|
||||
adapter_weights[idx].weights_b.data_ptr()
|
||||
if idx in adapter_weights
|
||||
else 0
|
||||
)
|
||||
for idx in segment_indices
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
use_sgmv = False
|
||||
lora_a_ptr = torch.tensor(
|
||||
[
|
||||
(
|
||||
adapter_weights[idx].weights_a_t.data_ptr()
|
||||
if idx in adapter_weights
|
||||
else 0
|
||||
)
|
||||
for idx in segment_indices
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
lora_b_ptr = torch.tensor(
|
||||
[
|
||||
(
|
||||
adapter_weights[idx].weights_b_t.data_ptr()
|
||||
if idx in adapter_weights
|
||||
else 0
|
||||
)
|
||||
for idx in segment_indices
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
|
||||
adapter_index_configs = {
|
||||
idx: adapter_weights[idx].adapter_config
|
||||
for idx in segment_indices
|
||||
if idx in adapter_weights
|
||||
}
|
||||
|
||||
adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
|
||||
|
||||
rank_indices = defaultdict(list)
|
||||
for segment_idx, adapter_idx in enumerate(segment_indices):
|
||||
if adapter_idx not in adapter_weights:
|
||||
continue
|
||||
rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
|
||||
|
||||
if prefill_head_indices is not None:
|
||||
j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
|
||||
for head_index in prefill_head_indices:
|
||||
# j cannot go out of bounds as that would mean there are tokens without corresponding adapters
|
||||
if head_index < meta.adapter_segments[j]:
|
||||
prefill_head_segment_ends[-1] += 1
|
||||
else:
|
||||
prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
|
||||
prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
|
||||
j += 1
|
||||
|
||||
rank_data = {}
|
||||
for rank, indices in rank_indices.items():
|
||||
tmp_shrink = None
|
||||
tmp_expand = None
|
||||
segment_starts = None
|
||||
segment_ends = None
|
||||
batch_indices = None
|
||||
|
||||
if use_sgmv:
|
||||
lora_a_ptr_indices = lora_a_ptr[indices]
|
||||
tmp_shrink, tmp_expand = get_tmp_tensors(
|
||||
lora_a_ptr_indices.size(0), rank, device
|
||||
)
|
||||
segment_starts = meta.adapter_segments[indices]
|
||||
segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
|
||||
if prefill_head_indices is not None:
|
||||
for i, segment_index in enumerate(indices):
|
||||
segment_starts[i] = prefill_head_segment_starts[segment_index]
|
||||
segment_ends[i] = prefill_head_segment_ends[segment_index]
|
||||
else:
|
||||
rank_indices = set(indices)
|
||||
batch_indices = [
|
||||
adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
|
||||
]
|
||||
batch_indices = [
|
||||
idx if idx in rank_indices else -1 for idx in batch_indices
|
||||
]
|
||||
batch_indices = torch.tensor(
|
||||
batch_indices, dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
rank_data[rank] = RankSegments(
|
||||
rank=rank,
|
||||
tmp_shrink=tmp_shrink,
|
||||
tmp_expand=tmp_expand,
|
||||
lora_a_ptr=lora_a_ptr[indices],
|
||||
lora_b_ptr=lora_b_ptr[indices],
|
||||
segment_starts=segment_starts,
|
||||
segment_ends=segment_ends,
|
||||
indices=batch_indices,
|
||||
)
|
||||
|
||||
return BatchLoraWeights(
|
||||
lora_a=lora_a,
|
||||
lora_b=lora_b,
|
||||
adapter_index_configs=adapter_index_configs,
|
||||
rank_data=rank_data,
|
||||
use_sgmv=use_sgmv,
|
||||
)
|
||||
|
||||
|
||||
def get_scaling_factor(
|
||||
lora_alpha: int,
|
||||
r: int,
|
||||
uses_rslora: bool = False,
|
||||
) -> float:
|
||||
"""Computes the scaling factor for the lora weights."""
|
||||
if uses_rslora:
|
||||
return lora_alpha / (r**0.5)
|
||||
return lora_alpha / r
|
||||
|
||||
|
||||
def _convert_lora(v: AdapterWeights) -> AdapterWeights:
|
||||
if hasattr(v, "lora_weights"):
|
||||
return v.lora_weights
|
||||
return v
|
|
@ -0,0 +1,158 @@
|
|||
# Origin: https://github.com/predibase/lorax
|
||||
# Path: lorax/server/lorax_server/adapters/weights.py
|
||||
# License: Apache License Version 2.0, January 2004
|
||||
|
||||
from abc import ABC, abstractclassmethod
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set, Type
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdapterBatchMetadata:
|
||||
# [batch_size]
|
||||
adapter_indices: torch.Tensor
|
||||
|
||||
# [num_adapters]
|
||||
adapter_set: Set[int]
|
||||
|
||||
# [num_segments + 1]
|
||||
adapter_segments: torch.Tensor
|
||||
|
||||
# [num_segments]
|
||||
# maps from segment index to adapter index, i.e.:
|
||||
# segment_indices[s] == adapter_indices[i]
|
||||
segment_indices: List[int]
|
||||
|
||||
|
||||
class AdapterWeights(ABC):
|
||||
@abstractclassmethod
|
||||
def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
|
||||
pass
|
||||
|
||||
@property
|
||||
def speculative_tokens(self) -> int:
|
||||
return 0
|
||||
|
||||
|
||||
class BatchAdapterWeights(ABC):
|
||||
@abstractclassmethod
|
||||
def has_adapter(self, adapter_index: int) -> bool:
|
||||
pass
|
||||
|
||||
@abstractclassmethod
|
||||
def key(cls) -> str:
|
||||
pass
|
||||
|
||||
@abstractclassmethod
|
||||
def load(
|
||||
cls,
|
||||
adapter_weights: Dict[int, AdapterWeights],
|
||||
meta: "AdapterBatchMetadata",
|
||||
prefill: bool,
|
||||
prefill_head_indices: torch.Tensor,
|
||||
) -> Optional["BatchAdapterWeights"]:
|
||||
pass
|
||||
|
||||
|
||||
class LayerAdapterWeights:
|
||||
"""Adapter weights that apply to a particular layer."""
|
||||
|
||||
def __init__(self):
|
||||
self.adapter_weights: Dict[int, AdapterWeights] = {}
|
||||
|
||||
def add_adapter(self, adapter_idx: int, weights: AdapterWeights):
|
||||
self.adapter_weights[adapter_idx] = weights
|
||||
|
||||
def remove_adapter(self, adapter_idx: int):
|
||||
if adapter_idx not in self.adapter_weights:
|
||||
return
|
||||
del self.adapter_weights[adapter_idx]
|
||||
|
||||
@property
|
||||
def max_speculative_tokens(self) -> int:
|
||||
return max(
|
||||
adapter_weights.speculative_tokens
|
||||
for adapter_weights in self.adapter_weights.values()
|
||||
)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return len(self.adapter_weights) == 0
|
||||
|
||||
def get_data(
|
||||
self,
|
||||
meta: AdapterBatchMetadata,
|
||||
prefill: bool,
|
||||
prefill_head_indices: Optional[torch.Tensor],
|
||||
) -> Dict[str, BatchAdapterWeights]:
|
||||
# bucket adapters by batch class
|
||||
adapter_batch_types: Dict[
|
||||
Type[BatchAdapterWeights], Dict[int, AdapterWeights]
|
||||
] = defaultdict(dict)
|
||||
for adapter_index, adapter_weights in self.adapter_weights.items():
|
||||
for batch_type in adapter_weights.get_batch_types():
|
||||
adapter_batch_types[batch_type][adapter_index] = adapter_weights
|
||||
|
||||
batch_data = {}
|
||||
for batch_type, adapter_weights in adapter_batch_types.items():
|
||||
batched_weights = batch_type.load(
|
||||
adapter_weights, meta, prefill, prefill_head_indices
|
||||
)
|
||||
if batched_weights is not None:
|
||||
batch_data[batch_type.key()] = batched_weights
|
||||
return batch_data
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdapterBatchData:
|
||||
meta: AdapterBatchMetadata
|
||||
|
||||
# layer type -> adapter type -> batch weight data
|
||||
data: Dict[str, Dict[str, BatchAdapterWeights]]
|
||||
|
||||
prefill: bool
|
||||
|
||||
@staticmethod
|
||||
def from_meta(
|
||||
meta: AdapterBatchMetadata,
|
||||
weights: Dict[str, LayerAdapterWeights],
|
||||
prefill: bool,
|
||||
prefill_head_indices: Optional[torch.Tensor],
|
||||
) -> "AdapterBatchData":
|
||||
data = {}
|
||||
for k, v in weights.items():
|
||||
if v.is_empty():
|
||||
continue
|
||||
data[k] = v.get_data(
|
||||
meta, prefill, prefill_head_indices if k == "lm_head" else None
|
||||
)
|
||||
return AdapterBatchData(meta=meta, data=data, prefill=prefill)
|
||||
|
||||
def ranks(self) -> Set[int]:
|
||||
# TODO(travis): refactor to be less coupled to lora implementation
|
||||
ranks = set()
|
||||
for layer_data in self.data.values():
|
||||
lora_data = layer_data.get("lora")
|
||||
if lora_data is None:
|
||||
continue
|
||||
|
||||
for rank_data in lora_data.rank_data.values():
|
||||
ranks.add(rank_data.rank)
|
||||
|
||||
return ranks
|
||||
|
||||
def layer_names(self) -> Set[str]:
|
||||
return set(self.data.keys())
|
||||
|
||||
def adapter_keys(self) -> Set[str]:
|
||||
adapter_keys = set()
|
||||
for layer_data in self.data.values():
|
||||
adapter_keys.update(layer_data.keys())
|
||||
return adapter_keys
|
||||
|
||||
@property
|
||||
def max_rank(self) -> int:
|
||||
ranks = self.ranks()
|
||||
return max(ranks) if len(ranks) > 0 else 0
|
|
@ -79,6 +79,18 @@ def serve(
|
|||
if otlp_endpoint is not None:
|
||||
setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)
|
||||
|
||||
lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
|
||||
|
||||
# split on comma and strip whitespace
|
||||
lora_adapter_ids = (
|
||||
[x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
|
||||
)
|
||||
|
||||
if len(lora_adapter_ids) > 0:
|
||||
logger.warning(
|
||||
f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
|
||||
)
|
||||
|
||||
# Downgrade enum into str for easier management later on
|
||||
quantize = None if quantize is None else quantize.value
|
||||
dtype = None if dtype is None else dtype.value
|
||||
|
@ -93,6 +105,7 @@ def serve(
|
|||
)
|
||||
server.serve(
|
||||
model_id,
|
||||
lora_adapter_ids,
|
||||
revision,
|
||||
sharded,
|
||||
quantize,
|
||||
|
@ -113,6 +126,7 @@ def download_weights(
|
|||
logger_level: str = "INFO",
|
||||
json_output: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
merge_lora: bool = False,
|
||||
):
|
||||
# Remove default handler
|
||||
logger.remove()
|
||||
|
@ -143,18 +157,28 @@ def download_weights(
|
|||
) is not None
|
||||
|
||||
if not is_local_model:
|
||||
try:
|
||||
adapter_config_filename = hf_hub_download(
|
||||
model_id, revision=revision, filename="adapter_config.json"
|
||||
)
|
||||
utils.download_and_unload_peft(
|
||||
model_id, revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
is_local_model = True
|
||||
utils.weight_files(model_id, revision, extension)
|
||||
return
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
pass
|
||||
# TODO: maybe reverse the default value of merge_lora?
|
||||
# currently by default we don't merge the weights with the base model
|
||||
if merge_lora:
|
||||
try:
|
||||
adapter_config_filename = hf_hub_download(
|
||||
model_id, revision=revision, filename="adapter_config.json"
|
||||
)
|
||||
utils.download_and_unload_peft(
|
||||
model_id, revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
is_local_model = True
|
||||
utils.weight_files(model_id, revision, extension)
|
||||
return
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
utils.peft.download_peft(
|
||||
model_id, revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
import json
|
||||
|
|
|
@ -12,3 +12,9 @@ from text_generation_server.layers.speculative import SpeculativeHead
|
|||
# Just to add the `load` methods.
|
||||
from text_generation_server.layers.layernorm import load_layer_norm
|
||||
from text_generation_server.layers.conv import load_conv2d
|
||||
|
||||
from text_generation_server.layers.lora import (
|
||||
LoraLinear,
|
||||
TensorParallelMultiAdapterLinear,
|
||||
TensorParallelAdapterRowLinear,
|
||||
)
|
||||
|
|
|
@ -0,0 +1,286 @@
|
|||
import math
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, List
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from accelerate import init_empty_weights
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from text_generation_server.utils.sgmv import (
|
||||
add_lora_a_bgmv,
|
||||
add_lora_b_bgmv,
|
||||
has_sgmv,
|
||||
lora_a_sgmv_cutlass,
|
||||
lora_b_sgmv_cutlass,
|
||||
orient_for_rank,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from text_generation_server.adapters import AdapterBatchData
|
||||
from text_generation_server.adapters.lora import BatchLoraWeights
|
||||
|
||||
|
||||
class LoraLinear(nn.Module):
|
||||
def __init__(
|
||||
self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup
|
||||
):
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
self.layer_id = layer_id
|
||||
self.process_group = process_group
|
||||
|
||||
def forward_layer_type(
|
||||
self,
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
adapter_data: "AdapterBatchData",
|
||||
layer_type: str,
|
||||
start_idx: int,
|
||||
end_idx: int,
|
||||
) -> torch.Tensor:
|
||||
if adapter_data is None:
|
||||
return result
|
||||
data = adapter_data.data.get(layer_type)
|
||||
data: Optional["BatchLoraWeights"] = (
|
||||
data.get("lora") if data is not None else None
|
||||
)
|
||||
|
||||
if has_sgmv() and data is not None and data.can_vectorize(self.process_group):
|
||||
# In tensor-parallel configurations, each GPU processes a specific segment of the output.
|
||||
# The 'result' tensor represents the full output, which can vary in size based on
|
||||
# the layer type (e.g., attention vs. feed-forward layers). We define the current
|
||||
# segment using start_idx and end_idx. If the segment size doesn't match this GPU's
|
||||
# slice of 'result', we create a zero tensor of the correct size for LoRA computation.
|
||||
# This approach ensures accurate LoRA application across various layer sizes and
|
||||
# configurations, adapting to different model architectures and parallelization strategies.
|
||||
#
|
||||
# Example scenarios where this is necessary:
|
||||
# 1. The adapter's size doesn't evenly divide across GPUs.
|
||||
# 2. We're processing the last segment which might be smaller.
|
||||
# 3. Different projection layers (q, k, v) have different sizes.
|
||||
if end_idx - start_idx != result.shape[1]:
|
||||
proj = torch.zeros_like(result[:, start_idx:end_idx])
|
||||
else:
|
||||
proj = result
|
||||
|
||||
for r, rank_segments in data.rank_data.items():
|
||||
lora_a_ptr = rank_segments.lora_a_ptr
|
||||
lora_b_ptr = rank_segments.lora_b_ptr
|
||||
|
||||
if lora_a_ptr is None or lora_b_ptr is None:
|
||||
raise ValueError("LoRA data is missing")
|
||||
|
||||
if data.use_sgmv:
|
||||
# Use SGMV for prefill
|
||||
v = lora_a_sgmv_cutlass(
|
||||
input,
|
||||
rank_segments.tmp_shrink,
|
||||
lora_a_ptr,
|
||||
rank_segments.segment_starts,
|
||||
rank_segments.segment_ends,
|
||||
self.layer_id,
|
||||
r,
|
||||
)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
v = self.collect_lora_a(v)
|
||||
|
||||
lora_b_sgmv_cutlass(
|
||||
proj,
|
||||
v,
|
||||
rank_segments.tmp_expand,
|
||||
lora_b_ptr,
|
||||
rank_segments.segment_starts,
|
||||
rank_segments.segment_ends,
|
||||
self.layer_id,
|
||||
)
|
||||
else:
|
||||
# Use BGMV for decode
|
||||
v = torch.zeros(
|
||||
(input.size(0), r), dtype=input.dtype, device=input.device
|
||||
)
|
||||
# TODO: error with [-1, 0], but not [0, -1]
|
||||
add_lora_a_bgmv(
|
||||
v,
|
||||
input,
|
||||
lora_a_ptr,
|
||||
rank_segments.indices,
|
||||
self.layer_id,
|
||||
)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
v = self.collect_lora_a(v)
|
||||
|
||||
add_lora_b_bgmv(
|
||||
proj,
|
||||
v,
|
||||
lora_b_ptr,
|
||||
rank_segments.indices,
|
||||
self.layer_id,
|
||||
)
|
||||
|
||||
if end_idx - start_idx != result.shape[1]:
|
||||
result[:, start_idx:end_idx] += proj
|
||||
else:
|
||||
for adapter_index in adapter_data.meta.adapter_set:
|
||||
if data is not None and data.has_adapter(adapter_index):
|
||||
adapter_mask = (
|
||||
(adapter_data.meta.adapter_indices == adapter_index)
|
||||
.to(input.dtype)
|
||||
.view(-1, 1)
|
||||
)
|
||||
layer_result = self.forward_lora(
|
||||
input, data, adapter_index, adapter_mask
|
||||
)
|
||||
result[:, start_idx:end_idx] += layer_result
|
||||
|
||||
return result
|
||||
|
||||
def forward_lora(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
data: "BatchLoraWeights",
|
||||
adapter_index: int,
|
||||
adapter_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
lora_a = data.lora_a[adapter_index][self.layer_id, :, :]
|
||||
lora_b = data.lora_b[adapter_index][self.layer_id, :, :]
|
||||
|
||||
lora_a = orient_for_rank(lora_a, lora_b.size(0))
|
||||
|
||||
a_out = input @ lora_a
|
||||
if self.process_group.size() > 1:
|
||||
a_out = self.collect_lora_a(a_out)
|
||||
|
||||
result = (a_out @ lora_b) * adapter_mask
|
||||
return result
|
||||
|
||||
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError("Implemented in subclasses")
|
||||
|
||||
|
||||
class TensorParallelMultiAdapterLinear(LoraLinear):
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: nn.Module,
|
||||
layer_id: int,
|
||||
layer_names: List[str],
|
||||
sizes: List[int],
|
||||
process_group: ProcessGroup,
|
||||
):
|
||||
super().__init__(base_layer, layer_id, process_group)
|
||||
self.layer_names = layer_names
|
||||
self.sizes = sizes
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
base_layer: nn.Module,
|
||||
layer_id: int,
|
||||
layer_names: List[str],
|
||||
sizes: List[int],
|
||||
process_group: ProcessGroup,
|
||||
):
|
||||
return TensorParallelMultiAdapterLinear(
|
||||
base_layer, layer_id, layer_names, sizes, process_group
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
|
||||
) -> torch.Tensor:
|
||||
result = self.base_layer(input)
|
||||
|
||||
# noop if no layer names are provided (e.g. for models without adapters)
|
||||
if self.layer_names is None:
|
||||
return result
|
||||
|
||||
# handle models like Bloom that have inputs of shape
|
||||
# (batch_size, sequence_length, hidden_size)
|
||||
# we need to reshape them to (batch_size * sequence_length, hidden_size)
|
||||
# for the LoRA computation, then reshape back
|
||||
prev_shape = result.shape
|
||||
is_3d = len(input.shape) >= 3
|
||||
if is_3d:
|
||||
input = input.reshape(-1, input.shape[-1])
|
||||
result = result.reshape(-1, result.shape[-1])
|
||||
|
||||
offset = 0
|
||||
for i, layer_name in enumerate(self.layer_names):
|
||||
start_idx = offset // self.process_group.size()
|
||||
# The 'sizes' parameter is essential in tensor-parallel setups for handling multiple
|
||||
# projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It
|
||||
# ensures correct slicing of the result tensor, accommodating variations like grouped-query
|
||||
# attention where k_proj and v_proj differ from q_proj. This allows precise application of
|
||||
# LoRA adapters to each sub-component of the multi-head attention mechanism, managing the
|
||||
# different projection sizes across layers and model architectures.
|
||||
if self.sizes is not None:
|
||||
offset += self.sizes[i]
|
||||
end_idx = offset // self.process_group.size()
|
||||
else:
|
||||
end_idx = result.shape[1]
|
||||
|
||||
result = self.forward_layer_type(
|
||||
result, input, adapter_data, layer_name, start_idx, end_idx
|
||||
)
|
||||
|
||||
if is_3d:
|
||||
result = result.reshape(prev_shape)
|
||||
|
||||
return result
|
||||
|
||||
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
|
||||
# Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.
|
||||
# We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.
|
||||
#
|
||||
# TODO(travis): this is not very efficient as we do an all-gather for every adapter,
|
||||
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
|
||||
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
|
||||
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
|
||||
gathered_tensors = [
|
||||
torch.empty_like(a_out) for _ in range(self.process_group.size())
|
||||
]
|
||||
torch.distributed.all_gather(gathered_tensors, a_out)
|
||||
return torch.cat(gathered_tensors, dim=1)
|
||||
|
||||
|
||||
class TensorParallelAdapterRowLinear(LoraLinear):
|
||||
def __init__(self, base_layer, layer_id, layer_name, process_group):
|
||||
super().__init__(base_layer, layer_id, process_group)
|
||||
self.layer_name = layer_name
|
||||
|
||||
@classmethod
|
||||
def load(cls, base_layer, layer_id, layer_name, process_group):
|
||||
return cls(base_layer, layer_id, layer_name, process_group)
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, adapter_data: "AdapterBatchData"
|
||||
) -> torch.Tensor:
|
||||
result = self.base_layer(input)
|
||||
|
||||
if self.layer_name is None:
|
||||
return result
|
||||
|
||||
# Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285
|
||||
stride = result.shape[-1] // self.process_group.size()
|
||||
start_idx = self.process_group.rank() * stride
|
||||
end_idx = (self.process_group.rank() + 1) * stride
|
||||
|
||||
self.forward_layer_type(
|
||||
result, input, adapter_data, self.layer_name, start_idx, end_idx
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:
|
||||
# Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.
|
||||
# We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.
|
||||
#
|
||||
# TODO(travis): this is not very efficient as we do an all-reduce for every adapter,
|
||||
# instead we could pre-allocate a (B, a, r) tensor for all adapters with the same
|
||||
# rank, compute `a_out` on each, and then slice them into the buffer as shown here:
|
||||
# https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609
|
||||
torch.distributed.all_reduce(a_out, group=self.process_group)
|
||||
return a_out
|
|
@ -6,7 +6,7 @@ from loguru import logger
|
|||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.models.auto import modeling_auto
|
||||
from huggingface_hub import hf_hub_download, HfApi
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
from pathlib import Path
|
||||
|
||||
from text_generation_server.utils.speculate import get_speculate, set_speculate
|
||||
|
@ -253,6 +253,7 @@ for data in ModelType:
|
|||
|
||||
def get_model(
|
||||
model_id: str,
|
||||
lora_adapter_ids: Optional[List[str]],
|
||||
revision: Optional[str],
|
||||
sharded: bool,
|
||||
quantize: Optional[str],
|
||||
|
@ -595,6 +596,7 @@ def get_model(
|
|||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
|
||||
|
|
|
@ -90,6 +90,7 @@ class BLOOMSharded(CausalLM):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
|
@ -538,6 +538,7 @@ class CausalLM(Model):
|
|||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
|
||||
super(CausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
|
@ -514,6 +514,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
|
|
|
@ -724,6 +724,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
|
|
|
@ -460,6 +460,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.model(
|
||||
|
|
|
@ -445,6 +445,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
token_embeds = self.embed_tokens(input_ids)
|
||||
position_embeds = self.embed_positions(position_ids)
|
||||
|
|
|
@ -38,6 +38,8 @@ from text_generation_server.layers import (
|
|||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
SpeculativeHead,
|
||||
TensorParallelMultiAdapterLinear,
|
||||
TensorParallelAdapterRowLinear,
|
||||
)
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
|
@ -51,43 +53,61 @@ if SYSTEM == "rocm":
|
|||
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
def load_attention(config, prefix, weights, layer_id):
|
||||
# Only defined in granite.
|
||||
bias = getattr(config, "attention_bias", False)
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
sizes = None
|
||||
prefixes = None
|
||||
|
||||
# if specific model type, load the correct attention
|
||||
if config.model_type == "phi3":
|
||||
return TensorParallelColumnLinear.load_qkv(
|
||||
prefix = f"{prefix}.qkv_proj"
|
||||
base_layer = TensorParallelColumnLinear.load_qkv(
|
||||
config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
prefix=prefix,
|
||||
weights=weights,
|
||||
bias=bias,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
)
|
||||
elif config.model_type == "baichuan":
|
||||
return TensorParallelColumnLinear.load_qkv(
|
||||
prefix = f"{prefix}.W_pack"
|
||||
base_layer = TensorParallelColumnLinear.load_qkv(
|
||||
config,
|
||||
prefix=f"{prefix}.W_pack",
|
||||
prefix=prefix,
|
||||
weights=weights,
|
||||
bias=bias,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
)
|
||||
else:
|
||||
prefixes = ["q_proj", "k_proj", "v_proj"]
|
||||
sizes = [
|
||||
head_size * config.num_attention_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
]
|
||||
base_layer = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
# otherwise, load the default attention based on the number of heads
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=bias,
|
||||
return TensorParallelMultiAdapterLinear.load(
|
||||
base_layer=base_layer,
|
||||
layer_id=layer_id,
|
||||
layer_names=prefixes,
|
||||
sizes=sizes,
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
|
||||
class FlashLlamaAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
index: int,
|
||||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
|
@ -121,14 +141,23 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
self.query_key_value = load_attention(config, prefix, weights, index)
|
||||
self.index = index
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||
o_proj,
|
||||
index,
|
||||
"o_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
|
@ -145,8 +174,9 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
adapter_data,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||
query, kv = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
|
@ -190,11 +220,13 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
max_s,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
return self.o_proj(
|
||||
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
|
||||
)
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
def __init__(self, prefix, config, weights, index):
|
||||
super().__init__()
|
||||
self.hidden_act = config.hidden_act
|
||||
self.act = (
|
||||
|
@ -209,29 +241,54 @@ class LlamaMLP(nn.Module):
|
|||
),
|
||||
)
|
||||
)
|
||||
prefixes = None
|
||||
sizes = None
|
||||
|
||||
# Fuse gate and up proj
|
||||
bias = getattr(config, "mlp_bias", False)
|
||||
if config.model_type == "phi3":
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
|
||||
gate_up_proj = TensorParallelColumnLinear.load_gate_up(
|
||||
config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
weights=weights,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
prefixes = [f"gate_proj", f"up_proj"]
|
||||
sizes = [
|
||||
config.intermediate_size,
|
||||
config.intermediate_size,
|
||||
]
|
||||
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=bias,
|
||||
)
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
|
||||
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||
gate_up_proj,
|
||||
index,
|
||||
layer_names=prefixes,
|
||||
sizes=sizes,
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||
down_proj,
|
||||
index,
|
||||
"down_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
|
@ -239,7 +296,7 @@ class LlamaMLP(nn.Module):
|
|||
# TODO: This is a hotfix to be removed & properly refactored.
|
||||
self.quantize = config.quantize
|
||||
|
||||
def forward(self, hidden_states):
|
||||
def forward(self, hidden_states, adapter_data):
|
||||
if (
|
||||
SYSTEM == "rocm"
|
||||
and self.hidden_act == "silu"
|
||||
|
@ -253,20 +310,27 @@ class LlamaMLP(nn.Module):
|
|||
device="cuda",
|
||||
)
|
||||
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
|
||||
return self.down_proj(out)
|
||||
return self.down_proj(out, adapter_data)
|
||||
else:
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
||||
return self.down_proj(
|
||||
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||
)
|
||||
|
||||
|
||||
class FlashLlamaLayer(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
def __init__(self, index, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.self_attn = FlashLlamaAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
index=index,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
self.mlp = LlamaMLP(
|
||||
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
|
||||
)
|
||||
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
self.input_layernorm = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||
|
@ -289,6 +353,7 @@ class FlashLlamaLayer(nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
adapter_data,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
|
@ -303,6 +368,7 @@ class FlashLlamaLayer(nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
adapter_data,
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
|
@ -310,7 +376,7 @@ class FlashLlamaLayer(nn.Module):
|
|||
attn_output, res
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(normed_attn_res_output)
|
||||
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
|
||||
|
||||
return mlp_output, attn_res
|
||||
|
||||
|
@ -325,6 +391,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashLlamaLayer(
|
||||
index=layer_id,
|
||||
prefix=(
|
||||
f"model.layers.{layer_id}"
|
||||
if not prefix
|
||||
|
@ -360,6 +427,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
max_s: int,
|
||||
true_max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
adapter_data,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
|
@ -382,6 +450,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
adapter_data,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
@ -423,6 +492,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.model(
|
||||
|
@ -436,6 +506,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
adapter_data=adapter_data,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
|
|
@ -38,6 +38,8 @@ from text_generation_server.layers import (
|
|||
TensorParallelEmbedding,
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
TensorParallelMultiAdapterLinear,
|
||||
TensorParallelAdapterRowLinear,
|
||||
)
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
|
@ -107,12 +109,7 @@ class MistralConfig(PretrainedConfig):
|
|||
|
||||
|
||||
class MistralAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
):
|
||||
def __init__(self, prefix: str, config, weights, layer_id):
|
||||
super().__init__()
|
||||
self.max_past = (
|
||||
config.sliding_window if config.sliding_window is not None else -1
|
||||
|
@ -140,7 +137,7 @@ class MistralAttention(torch.nn.Module):
|
|||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
self.query_key_value = TensorParallelColumnLinear.load_multi(
|
||||
query_key_value = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
|
@ -148,12 +145,31 @@ class MistralAttention(torch.nn.Module):
|
|||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
self.query_key_value = TensorParallelMultiAdapterLinear.load(
|
||||
query_key_value,
|
||||
layer_id,
|
||||
["q_proj", "k_proj", "v_proj"],
|
||||
sizes=[
|
||||
head_size * config.num_attention_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
],
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||
o_proj,
|
||||
layer_id,
|
||||
"o_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
|
@ -171,8 +187,9 @@ class MistralAttention(torch.nn.Module):
|
|||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||
query, kv = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
|
@ -224,11 +241,13 @@ class MistralAttention(torch.nn.Module):
|
|||
max_s,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
return self.o_proj(
|
||||
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
|
||||
)
|
||||
|
||||
|
||||
class MistralMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
def __init__(self, prefix, config, weights, layer_id):
|
||||
super().__init__()
|
||||
self.hidden_act = config.hidden_act
|
||||
self.act = (
|
||||
|
@ -244,19 +263,37 @@ class MistralMLP(nn.Module):
|
|||
)
|
||||
)
|
||||
# Fuse gate and up proj
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=False,
|
||||
)
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||
gate_up_proj,
|
||||
layer_id,
|
||||
["gate_proj", "up_proj"],
|
||||
sizes=[
|
||||
config.intermediate_size,
|
||||
config.intermediate_size,
|
||||
],
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||
down_proj,
|
||||
layer_id,
|
||||
"down_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
|
@ -264,7 +301,7 @@ class MistralMLP(nn.Module):
|
|||
# TODO: This is a hotfix to be removed & properly refactored.
|
||||
self.quantize = config.quantize
|
||||
|
||||
def forward(self, hidden_states):
|
||||
def forward(self, hidden_states, adapter_data):
|
||||
if (
|
||||
SYSTEM == "rocm"
|
||||
and self.hidden_act == "silu"
|
||||
|
@ -278,20 +315,27 @@ class MistralMLP(nn.Module):
|
|||
device="cuda",
|
||||
)
|
||||
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
|
||||
return self.down_proj(out)
|
||||
return self.down_proj(out, adapter_data)
|
||||
else:
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
||||
return self.down_proj(
|
||||
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||
)
|
||||
|
||||
|
||||
class MistralLayer(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
def __init__(self, prefix, config, weights, layer_id):
|
||||
super().__init__()
|
||||
self.self_attn = MistralAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
prefix=f"{prefix}.self_attn",
|
||||
config=config,
|
||||
weights=weights,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
self.mlp = MistralMLP(
|
||||
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
|
||||
)
|
||||
self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
self.input_layernorm = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||
|
@ -315,6 +359,7 @@ class MistralLayer(nn.Module):
|
|||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
|
@ -330,6 +375,7 @@ class MistralLayer(nn.Module):
|
|||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
|
@ -337,7 +383,7 @@ class MistralLayer(nn.Module):
|
|||
attn_output, res
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(normed_attn_res_output)
|
||||
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
|
||||
|
||||
return mlp_output, attn_res
|
||||
|
||||
|
@ -355,6 +401,7 @@ class MistralModel(torch.nn.Module):
|
|||
prefix=f"{prefix}.layers.{layer_id}",
|
||||
config=config,
|
||||
weights=weights,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
|
@ -381,6 +428,7 @@ class MistralModel(torch.nn.Module):
|
|||
max_s: int,
|
||||
true_max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
):
|
||||
hidden_states = inputs_embeds
|
||||
# Get rotary cos and sin for this forward
|
||||
|
@ -403,6 +451,7 @@ class MistralModel(torch.nn.Module):
|
|||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
@ -454,6 +503,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
true_max_s = max_s
|
||||
if prefill_cache_indices is not None:
|
||||
|
@ -476,6 +526,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
|||
max_s,
|
||||
true_max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
|
|
@ -638,6 +638,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
true_max_s = max_s
|
||||
if prefill_cache_indices is not None:
|
||||
|
|
|
@ -390,6 +390,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.gpt_neox(
|
||||
input_ids,
|
||||
|
|
|
@ -74,6 +74,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
|||
# Unused here
|
||||
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
# TODO This is odd but apparently pali gemma position ids start at 1.
|
||||
|
|
|
@ -400,6 +400,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
|
|
|
@ -359,6 +359,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
true_max_s = max_s
|
||||
if prefill_cache_indices is not None:
|
||||
|
|
|
@ -672,6 +672,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(
|
||||
input_ids,
|
||||
|
|
|
@ -483,6 +483,7 @@ class FlashSantacoderForCausalLM(nn.Module):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.transformer(
|
||||
input_ids,
|
||||
|
|
|
@ -525,6 +525,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
true_max_s = max_s
|
||||
if prefill_cache_indices is not None:
|
||||
|
|
|
@ -741,6 +741,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
|||
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
# Unused here
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
):
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None:
|
||||
|
|
|
@ -178,6 +178,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
|||
# Unused for this model
|
||||
pixel_attention_mask=None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
):
|
||||
inputs_embeds = self.language_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
|
|
|
@ -13,6 +13,7 @@ from opentelemetry import trace
|
|||
from transformers import PreTrainedTokenizerBase
|
||||
from typing import Iterable, Optional, Tuple, List, Type, Dict
|
||||
|
||||
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
|
||||
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
||||
from text_generation_server.utils.chunks import concat_text_chunks
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
@ -31,6 +32,7 @@ from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
|
|||
import text_generation_server.models.globals as tgi_globals
|
||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||
from text_generation_server.utils.dist import MEMORY_FRACTION
|
||||
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
|
||||
|
||||
from text_generation_server.utils.import_utils import (
|
||||
empty_cache,
|
||||
|
@ -114,6 +116,9 @@ class FlashCausalLMBatch(Batch):
|
|||
top_n_tokens: List[int]
|
||||
top_n_tokens_tensor: torch.Tensor
|
||||
|
||||
# Adapter metadata for each request
|
||||
adapter_meta: AdapterBatchMetadata
|
||||
|
||||
# Number of blocks in this batch
|
||||
num_blocks: int
|
||||
# Maximum number of blocks
|
||||
|
@ -174,6 +179,9 @@ class FlashCausalLMBatch(Batch):
|
|||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
|
||||
adapter_indices_list = []
|
||||
adapter_set = set()
|
||||
|
||||
# Cumulative length
|
||||
cumulative_length = 0
|
||||
cumulative_max_length = 0
|
||||
|
@ -225,6 +233,10 @@ class FlashCausalLMBatch(Batch):
|
|||
stopping_criterias.append(stopping_criteria)
|
||||
top_n_tokens.append(r.top_n_tokens)
|
||||
|
||||
adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(r.adapter_id, 0)
|
||||
adapter_indices_list.append(torch.full((input_length,), adapter_index))
|
||||
adapter_set.add(adapter_index)
|
||||
|
||||
# Paged attention
|
||||
# Remove one as the first token does not have a past
|
||||
speculative_length = get_speculate()
|
||||
|
@ -296,6 +308,10 @@ class FlashCausalLMBatch(Batch):
|
|||
max_length, input_length + max_new_tokens + speculative_length
|
||||
)
|
||||
|
||||
adapter_indices = torch.cat(adapter_indices_list).to(
|
||||
dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||
next_token_chooser_parameters, dtype, device, tokenizer
|
||||
)
|
||||
|
@ -339,6 +355,11 @@ class FlashCausalLMBatch(Batch):
|
|||
input_lengths, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
        adapter_segments = torch.tensor(
            adapter_segments, dtype=torch.int32, device=device
        )
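The per-token adapter_indices built above are collapsed into contiguous segments so the downstream kernels can process one adapter per segment. The real helper is find_segments from text_generation_server.utils.segments; the version below is a plausible re-implementation, shown only to make the data layout concrete.

from typing import List, Tuple

import torch


def find_segments_sketch(adapter_indices: torch.Tensor) -> Tuple[List[int], List[int]]:
    """Collapse a flat per-token adapter index tensor into contiguous segments."""
    indices = adapter_indices.tolist()
    segments = [0]        # segment boundaries (token positions)
    segment_indices = []  # adapter index owning each segment
    for i in range(1, len(indices)):
        if indices[i] != indices[i - 1]:
            segments.append(i)
            segment_indices.append(indices[i - 1])
    if indices:
        segments.append(len(indices))
        segment_indices.append(indices[-1])
    return segments, segment_indices


# Two prompts of 2 and 3 tokens on adapter 1 and the base model (index 0),
# then one 1-token prompt on adapter 2:
print(find_segments_sketch(torch.tensor([1, 1, 0, 0, 0, 2])))
# -> ([0, 2, 5, 6], [1, 0, 2])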
|
||||
if all_prefill_logprobs:
|
||||
prefill_head_indices = None
|
||||
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
|
||||
|
@ -393,6 +414,12 @@ class FlashCausalLMBatch(Batch):
|
|||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
num_blocks=num_blocks,
|
||||
max_blocks=max_blocks,
|
||||
adapter_meta=AdapterBatchMetadata(
|
||||
adapter_indices=adapter_indices,
|
||||
adapter_set=adapter_set,
|
||||
adapter_segments=adapter_segments,
|
||||
segment_indices=adapter_segment_indices,
|
||||
),
|
||||
speculative_ids=None,
|
||||
)
|
||||
|
||||
|
@ -443,6 +470,7 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
adapter_set = set()
|
||||
|
||||
num_blocks = 0
|
||||
max_blocks = 0
|
||||
|
@ -471,6 +499,11 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
top_n_tokens.append(self.top_n_tokens[idx])
|
||||
|
||||
adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(
|
||||
self.requests[idx].adapter_id, 0
|
||||
)
|
||||
adapter_set.add(adapter_index)
|
||||
|
||||
remaining_tokens = (
|
||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||
)
|
||||
|
@ -498,6 +531,7 @@ class FlashCausalLMBatch(Batch):
|
|||
# Index into tensors
|
||||
input_ids = self.input_ids[indices]
|
||||
position_ids = self.position_ids[indices]
|
||||
adapter_indices = self.adapter_meta.adapter_indices[indices]
|
||||
all_input_ids_tensor = self.all_input_ids_tensor[indices]
|
||||
block_tables_tensor = self.block_tables_tensor[indices]
|
||||
input_lengths_tensor = self.input_lengths_tensor[indices]
|
||||
|
@ -513,6 +547,11 @@ class FlashCausalLMBatch(Batch):
|
|||
# Move to GPU now that we have the whole tensor
|
||||
slot_indices = slot_indices.to(device)
|
||||
|
||||
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
|
||||
adapter_segments = torch.tensor(
|
||||
adapter_segments, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
return type(self)(
|
||||
batch_id=self.batch_id,
|
||||
requests=requests,
|
||||
|
@ -543,6 +582,12 @@ class FlashCausalLMBatch(Batch):
|
|||
num_blocks=num_blocks,
|
||||
max_blocks=max_blocks,
|
||||
speculative_ids=speculative_ids,
|
||||
adapter_meta=AdapterBatchMetadata(
|
||||
adapter_indices=adapter_indices,
|
||||
adapter_set=adapter_set,
|
||||
adapter_segments=adapter_segments,
|
||||
segment_indices=adapter_segment_indices,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -596,6 +641,14 @@ class FlashCausalLMBatch(Batch):
|
|||
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
|
||||
total_batch_size,
|
||||
)
|
||||
total_indices_size = sum(
|
||||
b.adapter_meta.adapter_indices.shape[0] for b in batches
|
||||
)
|
||||
adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
|
||||
total_indices_size
|
||||
)
|
||||
adapter_set = set()
|
||||
adapter_segment_builder = SegmentConcatBuilder()
|
||||
|
||||
start_slots = []
|
||||
block_tables = []
|
||||
|
@ -613,6 +666,7 @@ class FlashCausalLMBatch(Batch):
|
|||
# Cumulative length
|
||||
cumulative_batch_size = 0
|
||||
cumulative_slots = 0
|
||||
cumulative_adapter_indices_size = 0
|
||||
|
||||
for i, batch in enumerate(batches):
|
||||
requests.extend(batch.requests)
|
||||
|
@ -637,6 +691,21 @@ class FlashCausalLMBatch(Batch):
|
|||
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
|
||||
slots[slots_start_index:slots_end_index] = batch.slots
|
||||
|
||||
# Copy over adapter indices
|
||||
adapter_start_index = cumulative_adapter_indices_size
|
||||
adapter_end_index = (
|
||||
cumulative_adapter_indices_size
|
||||
+ batch.adapter_meta.adapter_indices.shape[0]
|
||||
)
|
||||
adapter_indices[adapter_start_index:adapter_end_index] = (
|
||||
batch.adapter_meta.adapter_indices
|
||||
)
|
||||
cumulative_adapter_indices_size = adapter_end_index
|
||||
adapter_set.update(batch.adapter_meta.adapter_set)
|
||||
adapter_segment_builder.concat(
|
||||
batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
|
||||
)
|
||||
|
||||
all_input_ids_tensor[
|
||||
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
|
||||
] = batch.all_input_ids_tensor[:, :max_length]
|
||||
|
@ -680,6 +749,8 @@ class FlashCausalLMBatch(Batch):
|
|||
else None
|
||||
)
|
||||
|
||||
adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
|
||||
|
||||
return cls(
|
||||
batch_id=batches[0].batch_id,
|
||||
requests=requests,
|
||||
|
@ -710,6 +781,12 @@ class FlashCausalLMBatch(Batch):
|
|||
num_blocks=num_blocks,
|
||||
max_blocks=max_blocks,
|
||||
speculative_ids=speculative_ids,
|
||||
adapter_meta=AdapterBatchMetadata(
|
||||
adapter_indices=adapter_indices,
|
||||
adapter_set=adapter_set,
|
||||
adapter_segments=adapter_segments,
|
||||
segment_indices=adapter_segment_indices,
|
||||
),
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
|
@ -719,6 +796,7 @@ class FlashCausalLMBatch(Batch):
|
|||
class FlashCausalLM(Model):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
model: torch.nn.Module,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_layers: int,
|
||||
|
@ -738,6 +816,7 @@ class FlashCausalLM(Model):
|
|||
self.kv_cache = []
|
||||
|
||||
super(FlashCausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=False,
|
||||
|
@ -895,12 +974,13 @@ class FlashCausalLM(Model):
|
|||
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
|
||||
|
||||
free_memory = get_free_memory(self.device, MEMORY_FRACTION)
|
||||
batch_num_blocks = batch.num_blocks if batch is not None else 0
|
||||
|
||||
num_blocks = (
|
||||
# Leave 5% for some wiggle room
|
||||
int((free_memory * 0.95) // total_cache_size)
|
||||
# Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
|
||||
+ batch.num_blocks
|
||||
+ batch_num_blocks
|
||||
)
|
||||
|
||||
del batch
|
||||
|
@ -1001,7 +1081,7 @@ class FlashCausalLM(Model):
|
|||
)
|
||||
|
||||
def forward(
|
||||
self, batch: FlashCausalLMBatch
|
||||
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
# Model Forward
|
||||
if batch.speculative_ids is not None:
|
||||
|
@ -1080,6 +1160,7 @@ class FlashCausalLM(Model):
|
|||
max_s=max_s,
|
||||
prefill_cache_indices=batch.prefill_cache_indices,
|
||||
lm_head_indices=lm_head_indices,
|
||||
adapter_data=adapter_data,
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
|
@ -1116,7 +1197,34 @@ class FlashCausalLM(Model):
|
|||
prefill = batch.cu_seqlen_prefill is not None
|
||||
prefill_logprobs = batch.prefill_next_token_indices is not None
|
||||
|
||||
        out, speculative_logits = self.forward(batch)
        # Update adapter indices for speculative tokens (if present)
        adapter_meta = batch.adapter_meta
        if batch.speculative_ids is not None:
            B, speculative_length = batch.speculative_ids.shape
            new_length = speculative_length + 1
            adapter_indices = (
                adapter_meta.adapter_indices.unsqueeze(-1)
                .expand(B, new_length)
                .reshape(-1)
            )
            adapter_segments = adapter_meta.adapter_segments * new_length
            adapter_meta = AdapterBatchMetadata(
                adapter_indices=adapter_indices,
                adapter_set=adapter_meta.adapter_set,
                adapter_segments=adapter_segments,
                segment_indices=adapter_meta.segment_indices,
            )

        # Assign pointers to adapter weights
        # TODO(travis): don't update this if indices haven't changed
        adapter_data = AdapterBatchData.from_meta(
            adapter_meta,
            self.layer_to_adapter_weights,
            prefill,
            batch.prefill_head_indices,
        )

        out, speculative_logits = self.forward(batch, adapter_data)

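When speculative ids are present, every request contributes speculative_length + 1 rows to the next forward pass, so its adapter index is repeated that many times and the segment boundaries are scaled by the same factor. A small worked example with made-up values:

import torch

adapter_indices = torch.tensor([1, 0, 2])  # one index per request during decode
speculative_length = 2                     # batch.speculative_ids.shape[1]
new_length = speculative_length + 1

expanded = (
    adapter_indices.unsqueeze(-1)
    .expand(len(adapter_indices), new_length)
    .reshape(-1)
)
print(expanded)  # tensor([1, 1, 1, 0, 0, 0, 2, 2, 2])

# Segment boundaries scale by the same factor: one token per request becomes
# new_length tokens per request.
adapter_segments = torch.tensor([0, 1, 2, 3], dtype=torch.int32)
print(adapter_segments * new_length)  # tensor([0, 3, 6, 9], dtype=torch.int32)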
if prefill:
|
||||
next_token_logits = (
|
||||
|
@ -1128,8 +1236,13 @@ class FlashCausalLM(Model):
|
|||
if prefill_logprobs
|
||||
else speculative_logits
|
||||
)
|
||||
next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
|
||||
len(batch)
|
||||
)
|
||||
|
||||
else:
|
||||
next_token_logits = out
|
||||
next_adapter_indices = batch.adapter_meta.adapter_indices
|
||||
|
||||
speculate = get_speculate()
|
||||
(
|
||||
|
@ -1195,6 +1308,12 @@ class FlashCausalLM(Model):
|
|||
# In decode, we do not need this as we can just increment position ids
|
||||
next_position_ids[i] = batch.position_ids[end_index - 1]
|
||||
|
||||
# Initialize adapter indices
|
||||
# In decode, we only have one token per row in the batch, so grab last index
|
||||
next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
|
||||
end_index - 1
|
||||
]
|
||||
|
||||
# Used to gather prefill logprobs
|
||||
# Copy batch.input_ids to prefill_token_indices
|
||||
if prefill_logprobs:
|
||||
|
@ -1220,6 +1339,16 @@ class FlashCausalLM(Model):
|
|||
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids
        batch.adapter_meta.adapter_indices = next_adapter_indices

        if prefill:
            # adjust segment lengths to account for all request lengths being 1 during decoding
            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
            batch.adapter_meta.adapter_segments = torch.tensor(
                adapter_segments,
                dtype=torch.int32,
                device=batch.adapter_meta.adapter_segments.device,
            )

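To see why the segments must be recomputed after prefill: during prefill there is one adapter index per prompt token, while in decode each request contributes exactly one token. Illustrative before/after values:

# Prefill: two requests of length 3 and 2, using adapter 1 and the base model (index 0).
prefill_adapter_indices = [1, 1, 1, 0, 0]   # segments [0, 3, 5], segment_indices [1, 0]

# Decode: the indices kept in next_adapter_indices collapse to one entry per request,
# so the segments are recomputed for the new layout.
decode_adapter_indices = [1, 0]             # segments [0, 1, 2], segment_indices [1, 0]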
if prefill and prefill_logprobs:
|
||||
# Get prefill logprobs
|
||||
|
|
|
@ -62,6 +62,7 @@ class FlashCohere(FlashCausalLM):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashCohere, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
|
|
|
@ -87,6 +87,7 @@ class FlashDbrx(FlashCausalLM):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashDbrx, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
|
|
|
@ -62,6 +62,7 @@ class FlashGemma(FlashCausalLM):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashGemma, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
|
|
|
@ -69,6 +69,7 @@ class FlashGPT2(FlashCausalLM):
|
|||
model = FlashGPT2ForCausalLM(prefix, config, weights)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashGPT2, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import os
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple, Dict, List
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||
|
@ -13,12 +14,24 @@ from text_generation_server.utils import (
|
|||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
hub,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
ADAPTER_LAYERS = [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
"o_proj",
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
"down_proj",
|
||||
]
|
||||
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
|
||||
|
||||
|
||||
class FlashLlama(FlashCausalLM):
|
||||
def __init__(
|
||||
|
@ -29,6 +42,7 @@ class FlashLlama(FlashCausalLM):
|
|||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
lora_adapter_ids: Optional[list] = [],
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
|
@ -78,6 +92,7 @@ class FlashLlama(FlashCausalLM):
|
|||
model = FlashLlamaForCausalLM(prefix, config, weights)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashLlama, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
|
@ -88,3 +103,69 @@ class FlashLlama(FlashCausalLM):
|
|||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_adapter_loading(self) -> bool:
|
||||
return True
|
||||
|
||||
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
|
||||
layer_weights = {}
|
||||
|
||||
prefix = "model.layers"
|
||||
|
||||
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
|
||||
# that have a language_model inside of the larger model.
|
||||
if hasattr(self.model, "language_model"):
|
||||
_model = self.model.language_model
|
||||
elif hasattr(self.model, "text_model"):
|
||||
_model = self.model.text_model
|
||||
else:
|
||||
_model = self.model
|
||||
|
||||
for i, layer in enumerate(_model.model.layers):
|
||||
layer_weights[(i, "q_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.q_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, "k_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.k_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, "v_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.v_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, "o_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.o_proj",
|
||||
layer.self_attn.o_proj,
|
||||
)
|
||||
|
||||
layer_weights[(i, "gate_proj")] = (
|
||||
f"{prefix}.{i}.mlp.gate_proj",
|
||||
layer.mlp.gate_up_proj,
|
||||
)
|
||||
layer_weights[(i, "up_proj")] = (
|
||||
f"{prefix}.{i}.mlp.up_proj",
|
||||
layer.mlp.gate_up_proj,
|
||||
)
|
||||
layer_weights[(i, "down_proj")] = (
|
||||
f"{prefix}.{i}.mlp.down_proj",
|
||||
layer.mlp.down_proj,
|
||||
)
|
||||
|
||||
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
|
||||
return layer_weights
|
||||
|
||||
@property
|
||||
def adapter_layers(self) -> List[str]:
|
||||
return ADAPTER_LAYERS
|
||||
|
||||
@property
|
||||
def default_traced_adapter_layers(self) -> List[str]:
|
||||
return ["q_proj", "v_proj"]
|
||||
|
||||
def get_num_layers_for_type(self, layer_type: str) -> int:
|
||||
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
|
||||
|
||||
def is_row_parallel(self, layer_type: str) -> bool:
|
||||
return layer_type in ROW_PARALLEL
|
||||
|
|
|
@ -3,7 +3,7 @@ import torch.distributed
|
|||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoTokenizer, AutoConfig
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Dict, List
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.flash_causal_lm import set_sliding_window
|
||||
|
@ -21,6 +21,18 @@ from text_generation_server.utils.import_utils import SYSTEM
|
|||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
ADAPTER_LAYERS = [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
"o_proj",
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
"down_proj",
|
||||
]
|
||||
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
|
||||
|
||||
|
||||
class BaseFlashMistral(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -83,6 +95,7 @@ class BaseFlashMistral(FlashCausalLM):
|
|||
torch.distributed.barrier(group=self.process_group)
|
||||
num_layers, num_kv_heads, head_size = self.get_layer_config(model)
|
||||
super().__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=num_layers,
|
||||
|
@ -102,6 +115,75 @@ class BaseFlashMistral(FlashCausalLM):
|
|||
model.model.head_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_adapter_loading(self) -> bool:
|
||||
return True
|
||||
|
||||
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
|
||||
layer_weights = {}
|
||||
|
||||
prefix = "model.layers"
|
||||
|
||||
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
|
||||
# that have a language_model inside of the larger model.
|
||||
if hasattr(self.model, "language_model"):
|
||||
_model = self.model.language_model
|
||||
elif hasattr(self.model, "text_model"):
|
||||
_model = self.model.text_model
|
||||
else:
|
||||
_model = self.model
|
||||
|
||||
for i, layer in enumerate(_model.model.layers):
|
||||
layer_weights[(i, "q_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.q_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, "k_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.k_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, "v_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.v_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, "o_proj")] = (
|
||||
f"{prefix}.{i}.self_attn.o_proj",
|
||||
layer.self_attn.o_proj,
|
||||
)
|
||||
|
||||
# TODO: this is a hack to avoid the gate_proj for
# FlashStarcoder2 that doesn't have these layers
|
||||
if hasattr(layer.mlp, "gate_up_proj"):
|
||||
layer_weights[(i, "gate_proj")] = (
|
||||
f"{prefix}.{i}.mlp.gate_proj",
|
||||
layer.mlp.gate_up_proj,
|
||||
)
|
||||
layer_weights[(i, "up_proj")] = (
|
||||
f"{prefix}.{i}.mlp.up_proj",
|
||||
layer.mlp.gate_up_proj,
|
||||
)
|
||||
layer_weights[(i, "down_proj")] = (
|
||||
f"{prefix}.{i}.mlp.down_proj",
|
||||
layer.mlp.down_proj,
|
||||
)
|
||||
|
||||
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
|
||||
return layer_weights
|
||||
|
||||
@property
|
||||
def adapter_layers(self) -> List[str]:
|
||||
return ADAPTER_LAYERS
|
||||
|
||||
@property
|
||||
def default_traced_adapter_layers(self) -> List[str]:
|
||||
return ["q_proj", "v_proj"]
|
||||
|
||||
def get_num_layers_for_type(self, layer_type: str) -> int:
|
||||
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
|
||||
|
||||
def is_row_parallel(self, layer_type: str) -> bool:
|
||||
return layer_type in ROW_PARALLEL
|
||||
|
||||
|
||||
class FlashMistral(BaseFlashMistral):
|
||||
def __init__(
|
||||
|
|
|
@ -69,6 +69,7 @@ class FlashNeoXSharded(FlashCausalLM):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashNeoXSharded, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.gpt_neox.layers),
|
||||
|
|
|
@ -90,6 +90,7 @@ class FlashPhi(FlashCausalLM):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashPhi, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
|
|
|
@ -71,6 +71,7 @@ class FlashQwen2(BaseFlashMistral):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(BaseFlashMistral, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
|
|
|
@ -78,6 +78,7 @@ class FlashRWSharded(FlashCausalLM):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashRWSharded, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.transformer.h),
|
||||
|
|
|
@ -80,6 +80,7 @@ class FlashSantacoderSharded(FlashCausalLM):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashSantacoderSharded, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.transformer.h),
|
||||
|
|
|
@ -70,6 +70,7 @@ class FlashStarcoder2(BaseFlashMistral):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(BaseFlashMistral, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
|
|
|
@ -212,6 +212,7 @@ class GalacticaSharded(CausalLM):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
import os
|
||||
from loguru import logger
|
||||
from typing import Dict
|
||||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
# This is overridden by the cli
|
||||
|
@ -32,3 +33,14 @@ MODEL_ID = None
|
|||
def set_model_id(model_id: str):
|
||||
global MODEL_ID
|
||||
MODEL_ID = model_id
|
||||
|
||||
|
||||
# NOTE: eventually we should move this into the router and pass back the
|
||||
# index in all cases.
|
||||
global ADAPTER_TO_INDEX
|
||||
ADAPTER_TO_INDEX: Dict[str, int] = None
|
||||
|
||||
|
||||
def set_adapter_to_index(adapter_to_index: Dict[str, int]):
    global ADAPTER_TO_INDEX
    ADAPTER_TO_INDEX = adapter_to_index
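The ADAPTER_TO_INDEX mapping set here is what FlashCausalLMBatch.from_pb consults via tgi_globals.ADAPTER_TO_INDEX.get(r.adapter_id, 0), with index 0 reserved for the base model. A small illustration, using made-up adapter ids:

import text_generation_server.models.globals as tgi_globals

# At startup the i-th id passed via lora_adapter_ids gets index i + 1;
# index 0 is implicitly the base model.
tgi_globals.set_adapter_to_index({"adapter-a": 1, "adapter-b": 2})

# Per-request lookup as done when building a batch: unknown or empty
# adapter ids fall back to 0, i.e. base model weights only.
for adapter_id in ["adapter-a", "adapter-b", ""]:
    print(adapter_id or "<base>", "->", tgi_globals.ADAPTER_TO_INDEX.get(adapter_id, 0))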
|
|
|
@ -65,6 +65,7 @@ class GPTNeoxSharded(CausalLM):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
|
@ -83,6 +83,7 @@ class IDEFICSSharded(IdeficsCausalLM):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(IdeficsCausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
|
@ -634,6 +634,7 @@ class IdeficsCausalLM(Model):
|
|||
tokenizer.add_special_tokens({"pad_token": "<unk>"})
|
||||
|
||||
super(IdeficsCausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
|
@ -453,6 +453,7 @@ class Mamba(Model):
|
|||
model = MambaModel(config, weights)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(Mamba, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
|
@ -2,12 +2,24 @@ import inspect
|
|||
import torch
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple, Optional, TypeVar, Type
|
||||
from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
|
||||
from collections import defaultdict
|
||||
from transformers import PreTrainedTokenizerBase, PretrainedConfig
|
||||
|
||||
from text_generation_server.models.types import Batch, Generation
|
||||
from text_generation_server.utils.speculate import get_speculate
|
||||
from text_generation_server.pb.generate_pb2 import InfoResponse
|
||||
from text_generation_server.adapters.weights import LayerAdapterWeights
|
||||
from text_generation_server.utils.adapter import (
|
||||
load_and_merge_adapters,
|
||||
AdapterParameters,
|
||||
AdapterSource,
|
||||
)
|
||||
from loguru import logger
|
||||
|
||||
|
||||
BASE_MODEL_ADAPTER_ID = "__base_model__"
|
||||
|
||||
|
||||
B = TypeVar("B", bound=Batch)
|
||||
|
||||
|
@ -15,6 +27,7 @@ B = TypeVar("B", bound=Batch)
|
|||
class Model(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
model: torch.nn.Module,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
requires_padding: bool,
|
||||
|
@ -24,7 +37,9 @@ class Model(ABC):
|
|||
world_size: int = 1,
|
||||
sliding_window: Optional[int] = None,
|
||||
speculate: Optional[int] = None,
|
||||
adapter_id: str = BASE_MODEL_ADAPTER_ID,
|
||||
):
|
||||
self.model_id = model_id
|
||||
self.model = model.eval()
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
|
@ -42,6 +57,13 @@ class Model(ABC):
|
|||
self.world_size = world_size
|
||||
self.sliding_window = sliding_window if sliding_window != -1 else None
|
||||
|
||||
self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
|
||||
LayerAdapterWeights
|
||||
)
|
||||
self.target_to_layer = self.adapter_target_to_layer()
|
||||
self.loaded_adapters = set()
|
||||
self.static_adapter_id = adapter_id
|
||||
|
||||
if speculate is None:
|
||||
speculate = get_speculate()
|
||||
self.speculate = speculate
|
||||
|
@ -119,3 +141,136 @@ class Model(ABC):
|
|||
raise RuntimeError(
|
||||
f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_adapter_loading(self) -> bool:
|
||||
return False
|
||||
|
||||
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def adapter_layers(self) -> List[str]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def default_traced_adapter_layers(self) -> List[str]:
|
||||
return []
|
||||
|
||||
def get_num_layers_for_type(self, layer_type: str) -> int:
|
||||
return 0
|
||||
|
||||
def is_row_parallel(self, layer_type: str) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def max_speculative_tokens(self) -> int:
|
||||
return max(
|
||||
[
|
||||
weights.max_speculative_tokens
|
||||
for weights in self.layer_to_adapter_weights.values()
|
||||
],
|
||||
default=0,
|
||||
)
|
||||
|
||||
    def load_adapter(
        self,
        adapter_parameters: AdapterParameters,
        adapter_source: AdapterSource,
        adapter_index: int,
        api_token: str,
        dynamic: bool = True,
    ):
        """Loads adapter weights from disk / host memory on the GPU.

        adapter_id must be `BASE_MODEL_ADAPTER_ID` if the adapter is statically
        loaded into the model. Otherwise, the adapter weights are applied during
        the forward pass and stored separately from the base model parameters.
        """
|
||||
if adapter_index in self.loaded_adapters:
|
||||
# Adapter already loaded
|
||||
return
|
||||
|
||||
if not self.supports_adapter_loading:
|
||||
raise ValueError("This model does not support adapter loading.")
|
||||
|
||||
if dynamic and not self.dynamic_adapter_loading_enabled:
|
||||
raise ValueError(
|
||||
f"This model was initialized with the adapter {self.static_adapter_id} "
|
||||
f"and therefore does not support dynamic adapter loading. "
|
||||
f"Please initialize a new model instance from the base model in "
|
||||
f"order to use the dynamic adapter loading feature."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
|
||||
)
|
||||
weight_names = tuple([v[0] for v in self.target_to_layer.values()])
|
||||
(
|
||||
module_map,
|
||||
adapter_config,
|
||||
adapter_weight_names,
|
||||
adapter_tokenizer,
|
||||
) = load_and_merge_adapters(
|
||||
self.model_id,
|
||||
adapter_parameters,
|
||||
adapter_source,
|
||||
adapter_index,
|
||||
weight_names,
|
||||
api_token,
|
||||
False,
|
||||
)
|
||||
|
||||
unused_weight_names = adapter_weight_names.copy()
|
||||
for layer_name in self.adapter_layers:
|
||||
adapter_weights = adapter_config.load_batched_adapter_weights(
|
||||
self,
|
||||
module_map,
|
||||
layer_name,
|
||||
unused_weight_names,
|
||||
dynamic,
|
||||
)
|
||||
|
||||
if adapter_weights is None:
|
||||
continue
|
||||
|
||||
layer_weights = self.layer_to_adapter_weights[layer_name]
|
||||
layer_weights.add_adapter(adapter_index, adapter_weights)
|
||||
|
||||
if len(unused_weight_names) > 0:
|
||||
logger.warning(
|
||||
f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
|
||||
)
|
||||
|
||||
if adapter_tokenizer is not None:
|
||||
self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)
|
||||
|
||||
self.loaded_adapters.add(adapter_index)
|
||||
|
||||
def offload_adapter(
|
||||
self,
|
||||
adapter_parameters: AdapterParameters,
|
||||
adapter_source: AdapterSource,
|
||||
adapter_index: int,
|
||||
):
|
||||
"""Offloads the adapter weights from GPU to CPU or disk."""
|
||||
if adapter_index not in self.loaded_adapters:
|
||||
# Adapter already offloaded
|
||||
return
|
||||
|
||||
if not self.supports_adapter_loading:
|
||||
raise ValueError("This model does not support adapter loading.")
|
||||
|
||||
if not self.dynamic_adapter_loading_enabled:
|
||||
raise ValueError(
|
||||
f"This model was initialized with the adapter {self.static_adapter_id} "
|
||||
f"and therefore does not support dynamic adapter loading. "
|
||||
f"Please initialize a new model instance from the base model in "
|
||||
f"order to use the dynamic adapter loading feature."
|
||||
)
|
||||
|
||||
for layer_name in self.adapter_layers:
|
||||
if layer_name in self.layer_to_adapter_weights:
|
||||
self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index)
|
||||
|
||||
self.loaded_adapters.remove(adapter_index)
|
||||
|
|
|
@ -90,6 +90,7 @@ class MPTSharded(CausalLM):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=False,
|
||||
|
|
|
@ -63,6 +63,7 @@ class OPTSharded(CausalLM):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
|
@ -60,6 +60,7 @@ class Phi(CausalLM):
|
|||
model = PhiForCausalLM(config, weights)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
|
@ -62,6 +62,7 @@ class RW(CausalLM):
|
|||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
|
||||
super(CausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
|
@ -62,6 +62,7 @@ class SantaCoder(CausalLM):
|
|||
)
|
||||
|
||||
super(CausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
|
@ -575,6 +575,7 @@ class Seq2SeqLM(Model):
|
|||
tokenizer.bos_token_id = model.config.decoder_start_token_id
|
||||
|
||||
super(Seq2SeqLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
|
@ -73,6 +73,7 @@ class T5Sharded(Seq2SeqLM):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(Seq2SeqLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
|
|
|
@ -222,7 +222,9 @@ class VlmCausalLM(BaseFlashMistral):
|
|||
return VlmCausalLMBatch
|
||||
|
||||
def forward(
|
||||
self, batch: VlmCausalLMBatch
|
||||
self,
|
||||
batch: VlmCausalLMBatch,
|
||||
adapter_data: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
# Model Forward
|
||||
if batch.speculative_ids is not None:
|
||||
|
|
|
@ -29,7 +29,10 @@ except (ImportError, NotImplementedError):
|
|||
|
||||
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
||||
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
||||
from text_generation_server.models.globals import set_model_id
|
||||
from text_generation_server.models.globals import set_model_id, set_adapter_to_index
|
||||
from text_generation_server.utils.adapter import (
|
||||
AdapterParameters,
|
||||
)
|
||||
|
||||
|
||||
class SignalHandler:
|
||||
|
@ -192,6 +195,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||
|
||||
def serve(
|
||||
model_id: str,
|
||||
lora_adapter_ids: Optional[List[str]],
|
||||
revision: Optional[str],
|
||||
sharded: bool,
|
||||
quantize: Optional[str],
|
||||
|
@ -203,6 +207,7 @@ def serve(
|
|||
):
|
||||
async def serve_inner(
|
||||
model_id: str,
|
||||
lora_adapter_ids: Optional[List[str]],
|
||||
revision: Optional[str],
|
||||
sharded: bool = False,
|
||||
quantize: Optional[str] = None,
|
||||
|
@ -211,6 +216,7 @@ def serve(
|
|||
trust_remote_code: bool = False,
|
||||
):
|
||||
unix_socket_template = "unix://{}-{}"
|
||||
adapter_to_index = {}
|
||||
if sharded:
|
||||
server_urls = [
|
||||
unix_socket_template.format(uds_path, rank)
|
||||
|
@ -224,6 +230,7 @@ def serve(
|
|||
try:
|
||||
model = get_model(
|
||||
model_id,
|
||||
lora_adapter_ids,
|
||||
revision,
|
||||
sharded,
|
||||
quantize,
|
||||
|
@ -232,10 +239,33 @@ def serve(
|
|||
trust_remote_code,
|
||||
max_input_tokens,
|
||||
)
|
||||
|
||||
            if len(lora_adapter_ids) > 0:
                for index, adapter_id in enumerate(lora_adapter_ids):
                    # TODO: improve non merged adapter loading and long term
                    # improve adapter loading as a whole
                    adapter_parameters = AdapterParameters(
                        adapter_ids=[adapter_id],
                        weights=None,  # will be set to 1
                        merge_strategy=0,
                        density=1.0,
                        majority_sign_method=0,
                    )
                    adapter_index = index + 1
                    adapter_to_index[adapter_id] = adapter_index
                    model.load_adapter(
                        adapter_parameters,
                        None,  # adapter_source
                        adapter_index,
                        None,  # api_token
                        False,  # dynamic
                    )

except Exception:
|
||||
logger.exception("Error when initializing model")
|
||||
raise
|
||||
|
||||
set_adapter_to_index(adapter_to_index)
|
||||
server = aio.server(
|
||||
interceptors=[
|
||||
ExceptionInterceptor(),
|
||||
|
@ -266,6 +296,13 @@ def serve(
|
|||
set_model_id(model_id)
|
||||
asyncio.run(
|
||||
serve_inner(
|
||||
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
|
||||
model_id,
|
||||
lora_adapter_ids,
|
||||
revision,
|
||||
sharded,
|
||||
quantize,
|
||||
speculate,
|
||||
dtype,
|
||||
trust_remote_code,
|
||||
)
|
||||
)
|
||||
|
|
|
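Once the adapters passed in lora_adapter_ids are loaded at startup, a single request can be routed to one of them. The example below is a hedged sketch: the endpoint, the adapter id, and the exact adapter_id request parameter are assumptions based on this PR's request-schema changes, so adjust them to your deployment.

import requests

# Placeholder endpoint and adapter id; adjust to your deployment.
response = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "What is deep learning?",
        "parameters": {
            "max_new_tokens": 64,
            # Selects one of the adapters the server was started with;
            # omit the field to generate with the base model.
            "adapter_id": "my-org/my-lora-adapter",
        },
    },
)
print(response.json()["generated_text"])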
@@ -0,0 +1,196 @@
# Origin:  https://github.com/predibase/lorax
# Path:    lorax/server/lorax_server/utils/adapter.py
# License:  Apache License Version 2.0, January 2004

import warnings
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Set, Tuple

from safetensors.torch import load_file
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

from text_generation_server.pb import generate_pb2
from text_generation_server.utils.merges.strategies import merge_adapters

from text_generation_server.utils import hub
from text_generation_server.adapters.lora import LoraConfig


if TYPE_CHECKING:
    from text_generation_server.adapters.config import AdapterConfig, ModuleMap


BASE_MODEL_ADAPTER_ID = "__base_model__"


@dataclass
class AdapterParameters:
    adapter_ids: Tuple[str]
    weights: Tuple[float]
    merge_strategy: NotImplemented
    density: float
    majority_sign_method: NotImplemented


@dataclass
class AdapterSource:
    adapter_id: str
    model_id: str
    revision: str


def load_and_merge_adapters(
    model_id: str,
    adapter_parameters: AdapterParameters,
    adapter_source: str,
    adapter_index: int,
    weight_names: Tuple[str],
    api_token: str,
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
    if len(adapter_parameters.adapter_ids) == 1:
        return load_module_map(
            model_id,
            adapter_parameters.adapter_ids[0],
            adapter_source,
            weight_names,
            api_token,
            trust_remote_code,
        )

    adapter_params = AdapterParametersContainer(
        adapter_parameters, adapter_source, adapter_index
    )
    return _load_and_merge(
        model_id, adapter_params, weight_names, api_token, trust_remote_code
    )


@dataclass
class AdapterParametersContainer:
    adapter_parameters: AdapterParameters
    adapter_source: str
    adapter_index: int

    def __hash__(self) -> int:
        return self.adapter_index


@lru_cache(maxsize=32)
def _load_and_merge(
    model_id: str,
    adapter_params: AdapterParametersContainer,
    weight_names: Tuple[str],
    api_token: str,
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
    params = adapter_params.adapter_parameters

    adapters_to_merge = []
    merged_weight_names = set()
    tokenizer = None
    for adapter_id in params.adapter_ids:
        if adapter_id == BASE_MODEL_ADAPTER_ID:
            raise ValueError("Base model adapter cannot be merged.")

        module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
            load_module_map(
                model_id,
                adapter_id,
                adapter_params.adapter_source,
                weight_names,
                api_token,
                trust_remote_code,
            )
        )

        adapters_to_merge.append((module_map, adapter_config))
        merged_weight_names = merged_weight_names.union(adapter_weight_names)
        if tokenizer is None:
            tokenizer = adapter_tokenizer

    if len(adapters_to_merge) == 0:
        raise ValueError("No adapters to merge.")

    module_map, adapter_config = merge_adapters(adapters_to_merge, params)
    return module_map, adapter_config, merged_weight_names, tokenizer


def check_architectures(
    model_id: str,
    adapter_id: str,
    adapter_config: "AdapterConfig",
    trust_remote_code: bool = False,
):
    try:
        if not adapter_config.base_model_name_or_path:
            # Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None)
            return

        expected_config = AutoConfig.from_pretrained(
            model_id, trust_remote_code=trust_remote_code
        )
        model_config = AutoConfig.from_pretrained(
            adapter_config.base_model_name_or_path, trust_remote_code=trust_remote_code
        )
    except Exception as e:
        warnings.warn(
            f"Unable to check architecture compatibility for adapter '{adapter_id}' "
            f"against model '{model_id}'. Assuming they are compatible. Error: {e}"
        )
        return

    if model_config.architectures == expected_config.architectures:
        warnings.warn(
            f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. "
            f"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead."
        )
    else:
        # TODO(travis): revisit this when we support classification heads which will not use CausalLM
        raise ValueError(
            f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
            f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. "
            f"Use --model-id '{adapter_config.base_model_name_or_path}' instead."
        )


@lru_cache(maxsize=128)
def load_module_map(
    model_id: str,
    adapter_id: str,
    adapter_source: str,
    weight_names: Tuple[str],
    api_token: str,
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
    revision = "main"

    adapter_config = LoraConfig.load(adapter_id, api_token)
    if adapter_config.base_model_name_or_path != model_id:
        check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)

    adapter_filenames = hub._cached_adapter_weight_files(
        adapter_id, revision=revision, extension=".safetensors"
    )

    try:
        adapter_tokenizer = AutoTokenizer.from_pretrained(
            adapter_config.config_path,
            token=api_token,
            trust_remote_code=trust_remote_code,
        )
    except Exception:
        # Adapter does not have a tokenizer, so fallback to base model tokenizer
        adapter_tokenizer = None

    # load adapter weights from all shards (should have relatively small memory footprint)
    adapter_weights = {}
    for filename in adapter_filenames:
        adapter_weights.update(load_file(filename))

    # map the model weights to the relevant adapter weights (LoRA A and B matrices)
    module_map, adapter_weight_names = adapter_config.map_weights_for_model(
        adapter_weights, weight_names
    )
    return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
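As a reading aid, a minimal sketch of constructing the AdapterParameters dataclass defined above for a two-adapter merge; the import path and adapter names are assumptions, not shown in this hunk.

# Illustrative only: assumes the file above is importable as
# text_generation_server.utils.adapter in the server environment.
from text_generation_server.utils.adapter import AdapterParameters

# Describe two hypothetical adapters to be merged with equal weight.
params = AdapterParameters(
    adapter_ids=("org/adapter-a", "org/adapter-b"),
    weights=(1.0, 1.0),
    merge_strategy=0,        # same value server.py passes; merges/strategies.py currently hard-codes "linear"
    density=1.0,
    majority_sign_method=0,
)
print(params.adapter_ids)  # ('org/adapter-a', 'org/adapter-b')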
@@ -18,6 +18,17 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]


def _cached_adapter_weight_files(
    adapter_id: str, revision: Optional[str], extension: str
) -> List[str]:
    """Guess weight files from the cached revision snapshot directory"""
    d = _get_cached_revision_directory(adapter_id, revision)
    if not d:
        return []
    filenames = _adapter_weight_files_from_dir(d, extension)
    return filenames


def _cached_weight_files(
    model_id: str, revision: Optional[str], extension: str
) -> List[str]:
@@ -60,6 +71,33 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
    return filenames


def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]:
    # os.walk: do not iterate, just scan for depth 1, not recursively
    # see _weight_files_from_dir, that's also what is done there
    root, _, files = next(os.walk(str(d)))
    filenames = [
        os.path.join(root, f)
        for f in files
        if f.endswith(extension)
        and "arguments" not in f
        and "args" not in f
        and "training" not in f
    ]
    return filenames


def _adapter_config_files_from_dir(d: Path) -> List[str]:
    # os.walk: do not iterate, just scan for depth 1, not recursively
    # see _weight_files_from_dir, that's also what is done there
    root, _, files = next(os.walk(str(d)))
    filenames = [
        os.path.join(root, f)
        for f in files
        if f.endswith(".json") and "arguments" not in f and "args" not in f
    ]
    return filenames


def _get_cached_revision_directory(
    model_id: str, revision: Optional[str]
) -> Optional[Path]:
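To make the filtering rule above concrete, here is a tiny self-contained sketch applying the same comprehension to an in-memory file list; the file names are made up.

# Hypothetical snapshot contents; mirrors the filter in _adapter_weight_files_from_dir.
files = ["adapter_model.safetensors", "training_state.safetensors", "adapter_config.json"]
extension = ".safetensors"
weight_files = [
    f
    for f in files
    if f.endswith(extension)
    and "arguments" not in f
    and "args" not in f
    and "training" not in f
]
print(weight_files)  # ['adapter_model.safetensors']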
@@ -0,0 +1,223 @@
import copy
from abc import ABC
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union

import torch


class AdapterParameters:
    def __init__(
        self, adapter_ids, weights, merge_strategy, density, majority_sign_method
    ):
        self.adapter_ids = adapter_ids
        self.weights = weights
        self.merge_strategy = merge_strategy
        self.density = density
        self.majority_sign_method = majority_sign_method


from text_generation_server.utils.merges.utils import (
    calculate_majority_sign_mask,
    disjoint_merge,
    prune,
)

if TYPE_CHECKING:
    from text_generation_server.adapters.lora import LoraConfig
    from text_generation_server.utils.adapter import ModuleMap


def _apply_weights(
    tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor
) -> torch.Tensor:
    if isinstance(tensors, torch.Tensor):
        t = tensors
    else:
        t = torch.stack(tensors, dim=0)

    # element-wise weighting of each task tensor
    # need to unsqueeze weights to match task tensor dimensions
    # for multiplication to apply element-wise
    while len(t.shape) > len(w.shape):
        w = w.unsqueeze(-1)
    return t * w


class MergeStrategy(ABC):
    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        raise NotImplementedError()


class LinearMerge(MergeStrategy):
    def __init__(self, **kwargs):
        pass

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        weighted_task_tensors = _apply_weights(task_tensors, weights)
        return weighted_task_tensors.sum(dim=0)


class TiesMerge(MergeStrategy):
    def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
        self.density = density
        self.majority_sign_method = majority_sign_method

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        # sparsify
        task_tensors = [
            prune(tensor, self.density, method="magnitude") for tensor in task_tensors
        ]
        task_tensors = torch.stack(task_tensors, dim=0)

        # elect sign before applying weights
        majority_sign_mask = calculate_majority_sign_mask(
            task_tensors, method=self.majority_sign_method
        )
        weighted_task_tensors = _apply_weights(task_tensors, weights)

        # disjoint merge
        return disjoint_merge(weighted_task_tensors, majority_sign_mask)


class DareLinearMerge(MergeStrategy):
    def __init__(self, density: float, **kwargs):
        self.density = density

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        # sparsify
        task_tensors = [
            prune(tensor, self.density, method="random", rescale=True)
            for tensor in task_tensors
        ]
        weighted_task_tensors = _apply_weights(task_tensors, weights)
        return weighted_task_tensors.sum(dim=0)


class DareTiesMerge(MergeStrategy):
    def __init__(self, density: float, majority_sign_method: str = "total", **kwargs):
        self.density = density
        self.majority_sign_method = majority_sign_method

    def merge(
        self, task_tensors: List[torch.Tensor], weights: torch.Tensor
    ) -> torch.Tensor:
        # sparsify
        task_tensors = [
            prune(tensor, self.density, method="random", rescale=True)
            for tensor in task_tensors
        ]
        task_tensors = torch.stack(task_tensors, dim=0)

        # elect sign before applying weights
        majority_sign_mask = calculate_majority_sign_mask(
            task_tensors, method=self.majority_sign_method
        )
        weighted_task_tensors = _apply_weights(task_tensors, weights)

        # disjoint merge
        mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask)
        return mixed_task_tensors


strategy_registry: Dict[str, Type[MergeStrategy]] = {
    "linear": LinearMerge,
    "ties": TiesMerge,
    "dare_linear": DareLinearMerge,
    "dare_ties": DareTiesMerge,
}


def merge_adapters(
    adapters: List[Tuple["ModuleMap", "LoraConfig"]],
    merge_params: AdapterParameters,
) -> Tuple["ModuleMap", "LoraConfig"]:
    # strategy_name = MergeStrategyEnum.Name(merge_params.merge_strategy).lower()
    strategy_name = "linear"

    weights = merge_params.weights
    if not weights:
        weights = torch.ones(len(adapters))
    else:
        weights = torch.tensor(weights)

    merge_config = {
        "density": merge_params.density,
        # "majority_sign_method": MajoritySignMethodEnum.Name(
        #     merge_params.majority_sign_method
        # ).lower(),
        "majority_sign_method": "total",
    }
    merge_strategy = strategy_registry[strategy_name](**merge_config)

    module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict(
        lambda: defaultdict(lambda: defaultdict(list))
    )
    lora_configs = []
    weight_name_to_adapter_idx = defaultdict(list)

    # input is list of (module_map, lora_config) tuples
    # convert into dict[k][param_name] -> list of tensors
    for idx, (module_map, lora_config) in enumerate(adapters):
        for weight_name, data in module_map.items():
            weight_name_to_adapter_idx[weight_name].append(idx)
            for k, (param_data, param_name) in data.items():
                module_maps[weight_name][k][param_name].append(param_data)
        lora_configs.append(lora_config)

    # validate lora configs are compatible
    _validate_lora_configs(lora_configs)

    # merge tensors for each module such that we have a single ModuleMap:
    # dict[k] -> merged tensor
    merged_module_map: "ModuleMap" = defaultdict(dict)
    for weight_name, data in module_maps.items():
        indices = weight_name_to_adapter_idx[weight_name]
        param_weights = weights[indices]
        for k, param_data in data.items():
            for param_name, tensors in param_data.items():
                merged_tensor = merge_strategy.merge(tensors, param_weights)
                merged_module_map[weight_name][k] = (merged_tensor, param_name)

    # merge lora configs
    merged_lora_config = _merge_lora_configs(lora_configs)

    return merged_module_map, merged_lora_config


def _validate_lora_configs(lora_configs: List["LoraConfig"]):
    # check that all configs have the same rank
    ranks = set(lora_config.r for lora_config in lora_configs)
    if len(ranks) > 1:
        raise ValueError(
            f"unable to merge adapters, lora configs have different ranks: {ranks}"
        )

    if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs):
        raise ValueError(
            "unable to merge adapters, lora configs have no target modules"
        )


def _merge_lora_configs(lora_configs: List["LoraConfig"]) -> "LoraConfig":
    merged_lora_config = copy.copy(lora_configs[0])

    # merge target modules as a union operation
    merged_target_modules = sorted(
        set(
            module
            for lora_config in lora_configs
            for module in lora_config.target_modules
        )
    )
    merged_lora_config.target_modules = merged_target_modules

    return merged_lora_config
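As a toy check of the linear merge arithmetic above, the following self-contained sketch reproduces what LinearMerge.merge does for two adapters with weights 0.25 and 0.75; the tensors and weights are made up.

# Mirrors _apply_weights followed by sum(dim=0) in LinearMerge.merge.
import torch

task_tensors = [torch.ones(2, 2), 3 * torch.ones(2, 2)]
weights = torch.tensor([0.25, 0.75])

stacked = torch.stack(task_tensors, dim=0)   # shape [2, 2, 2]
w = weights.unsqueeze(-1).unsqueeze(-1)      # broadcast to [2, 1, 1]
merged = (stacked * w).sum(dim=0)
print(merged)  # every entry is 0.25 * 1 + 0.75 * 3 = 2.5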
@@ -0,0 +1,108 @@
# coding=utf-8
# From: https://github.com/huggingface/peft/pull/1364
# Copyright 2024-present the HuggingFace Inc. team.
# Modifications by Predibase, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Literal

import torch


def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor:
    """
    Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction
    `density`.

    Args:
        tensor (`torch.Tensor`): The tensor to prune.
        density (`float`): The fraction of values to preserve. Should be in [0,1].
    """
    mask = torch.zeros_like(tensor).reshape(-1)
    k = int(density * tensor.reshape(-1).shape[0])
    top_k = torch.topk(tensor.abs().reshape(-1), k=k, largest=True)
    mask[top_k[1]] = 1
    return tensor * mask.reshape(tensor.shape)


def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:
    """
    Prune random values of the task tensor, keeping the specified fraction `density`.

    Args:
        tensor (`torch.Tensor`): The tensor to prune.
        density (`float`): The fraction of values to preserve. Should be in [0,1].
        rescale (`bool`): Whether to rescale the result to preserve the expected value of the original tensor.
    """
    mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density))
    pruned_tensor = tensor * mask
    if rescale:
        torch.div(input=pruned_tensor, other=density)
    return pruned_tensor


def prune(
    tensor: torch.Tensor,
    density: float,
    method: Literal["magnitude", "random"],
    rescale: bool = False,
) -> torch.Tensor:
    """
    Prune the values of task tensors based on the `method`.

    Args:
        tensor (`torch.Tensor`): The tensor to prune.
        density (`float`): The fraction of values to preserve. Should be in [0,1].
        method (`str`): The method to use to prune. Should be one of ["magnitude", "random"].
        rescale (`bool`): Whether to rescale the result to preserve the expected value of the original tensor.
    """
    if density >= 1:
        return tensor
    elif density < 0:
        raise ValueError(f"Density should be >= 0, got {density}")
    if method == "magnitude":
        return magnitude_based_pruning(tensor, density)
    elif method == "random":
        return random_pruning(tensor, density, rescale=rescale)
    else:
        raise ValueError(f"Unknown method {method}")


def calculate_majority_sign_mask(
    tensor: torch.Tensor, method: Literal["total", "frequency"] = "total"
):
    """
    Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0.

    Args:
        tensor (`torch.Tensor`): The tensor to get the mask from.
        method (`str`): The method to use to get the mask. Should be one of ["total", "frequency"].
    """

    sign = tensor.sign()
    if method == "total":
        sign_magnitude = (sign * tensor.abs()).sum(dim=0)
    elif method == "frequency":
        sign_magnitude = sign.sum(dim=0)
    else:
        raise RuntimeError(f'Unimplemented mask method "{method}"')
    majority_sign = torch.where(sign_magnitude >= 0, 1, -1)
    return sign == majority_sign


def disjoint_merge(task_tensors, majority_sign_mask):
    mixed_task_tensors = (task_tensors * majority_sign_mask).sum(dim=0)
    num_params_preserved = majority_sign_mask.sum(dim=0)
    return mixed_task_tensors / torch.clamp(num_params_preserved, min=1.0)
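For concreteness, a worked example of the magnitude-based pruning above at density 0.5: keep the top half of entries by absolute value and zero the rest. The tensor values are made up.

# Reproduces magnitude_based_pruning on a 4-element tensor.
import torch

tensor = torch.tensor([0.1, -2.0, 0.3, 4.0])
density = 0.5
k = int(density * tensor.numel())        # keep 2 values
top_k = torch.topk(tensor.abs(), k=k)
mask = torch.zeros_like(tensor)
mask[top_k.indices] = 1
print(tensor * mask)                     # tensor([ 0., -2.,  0.,  4.])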
@@ -1,5 +1,5 @@
import os
import json
from typing import Union
from loguru import logger
import torch
@@ -43,3 +43,26 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
    model.save_pretrained(cache_dir, safe_serialization=True)
    model.config.save_pretrained(cache_dir)
    tokenizer.save_pretrained(cache_dir)


def download_peft(
    model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool
):
    torch_dtype = torch.float16
    try:
        _model = AutoPeftModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=True,
        )
    except Exception:
        _model = AutoPeftModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=True,
        )
    logger.info("Peft model downloaded.")
@@ -0,0 +1,66 @@
# Origin:  https://github.com/predibase/lorax
# Path:    lorax/server/lorax_server/utils/segments.py
# License:  Apache License Version 2.0, January 2004

from typing import List, Tuple, Union

import torch


def find_segments(
    adapter_indices: Union[torch.Tensor, List[int]]
) -> Tuple[List[int], List[int]]:
    segments = [0]
    segment_indices = []

    if isinstance(adapter_indices, torch.Tensor):
        # Calling .item() repeatedly on CUDA tensor is very slow, so we move it to CPU first
        adapter_indices = adapter_indices.cpu().tolist()

    start_index = 0
    for i in range(1, len(adapter_indices)):
        if adapter_indices[i] != adapter_indices[i - 1]:
            segments.append(i)
            segment_indices.append(adapter_indices[i - 1])
            start_index = i

    # Handle the last segment
    if start_index < len(adapter_indices):
        segments.append(len(adapter_indices))
        segment_indices.append(adapter_indices[-1])

    return segments, segment_indices


class SegmentConcatBuilder:
    def __init__(self):
        self.adapter_segment_indices = []
        self.adapter_segment_tensors = []

    def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]):
        # Update adapter segments
        if self.adapter_segment_tensors:
            # Because we have already processed at least one batch, remove the 0 start index
            # from this batch denoting the beginning of the segment, then offset all segment
            # positions by the value of the last segment in the previous batch to account for
            # the concatenation.
            adapter_segments = (
                adapter_segments[1:] + self.adapter_segment_tensors[-1][-1]
            )

        if (
            self.adapter_segment_indices
            and self.adapter_segment_indices[-1] == segment_indices[0]
        ):
            # If the last segment in the previous batch is the same as the first segment in this batch,
            # then we merge them together into a single segment. In effect, this means removing it from
            # the segment indices of this batch, and extending the segment span by removing the segment
            # end index from the previous batch.
            segment_indices = segment_indices[1:]
            self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1]

        self.adapter_segment_indices.extend(segment_indices)
        self.adapter_segment_tensors.append(adapter_segments)

    def build(self) -> Tuple[torch.Tensor, List[int]]:
        return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices
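As a reading aid, a minimal usage sketch for the segment utilities above; it assumes the module imports as text_generation_server.utils.segments and uses made-up adapter indices.

# Batch of 3 requests: rows 0-1 use adapter 1, row 2 uses adapter 0.
import torch
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments

segments, segment_indices = find_segments([1, 1, 0])
print(segments, segment_indices)  # [0, 2, 3] [1, 0]

# Concatenating a second batch ([0, 0]) that starts with the same adapter (0)
# merges the boundary segment instead of opening a new one.
builder = SegmentConcatBuilder()
builder.concat(torch.tensor(segments), segment_indices)
builder.concat(torch.tensor([0, 2]), [0])
adapter_segments, merged_indices = builder.build()
print(adapter_segments.tolist(), merged_indices)  # [0, 2, 5] [1, 0]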
@@ -0,0 +1,248 @@
# Origin:  https://github.com/predibase/lorax
# Path:    lorax/server/lorax_server/utils/sgmv.py
# License:  Apache License Version 2.0, January 2004

import os
import warnings
from functools import lru_cache
from typing import List, Tuple

import torch
import torch.nn.functional as F

try:
    import punica_kernels as _kernels

    HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
except ImportError:
    warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.")
    _kernels = None
    HAS_SGMV = False


MIN_SGMV_RANK = 8
MIN_RANK_CUSTOM = 16
MAX_RANK_CUSTOM = 128
SGMV_BLOCK_SIZE = 16
BGMV_MAX_RANK = 64


def has_sgmv() -> bool:
    return HAS_SGMV


def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
    """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size."""
    if not has_sgmv():
        return t

    # tensor parallelism will result in effective rank being divided by world_size,
    # so we need to scale the min rank to offset that effect
    min_rank = MIN_SGMV_RANK * world_size

    # if we're at or below the min rank, pad up to the min rank
    # otherwise, pad to the nearest multiple of the block size
    current_rank = t.size(dim)
    target_rank = (
        min_rank
        if current_rank <= min_rank
        else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
    )
    if current_rank == target_rank:
        return t

    pad_size = target_rank - current_rank

    # see complicated pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
    pad = [0, 0] * t.dim()
    pad[(t.dim() - dim - 1) * 2 + 1] = pad_size
    pad = tuple(pad)

    return F.pad(t, pad, mode="constant", value=0.0)


def use_cutlass_shrink(lora_rank: int) -> bool:
    return lora_rank < MIN_RANK_CUSTOM


def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
    if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
        return t.transpose(0, 1)
    return t
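Aside: a worked example of the pad_rank target-rank arithmetic above, using only the constants defined in this file (no kernels needed); world_size=2 is an assumed value.

# With world_size=2 the minimum padded rank is 16, and larger ranks round up
# to the next multiple of SGMV_BLOCK_SIZE (16).
MIN_SGMV_RANK = 8
SGMV_BLOCK_SIZE = 16
world_size = 2
min_rank = MIN_SGMV_RANK * world_size

for current_rank in (8, 16, 40):
    target_rank = (
        min_rank
        if current_rank <= min_rank
        else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
    )
    print(current_rank, "->", target_rank)  # 8 -> 16, 16 -> 16, 40 -> 48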
# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py
def add_lora_sgmv_cutlass(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_ptr: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.Tensor,
    s_end: torch.Tensor,
    layer_idx: int,
    lora_rank: int,
):
    """
    Semantics:
        y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])

    Args:
        y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
        x: Shape: `[B, H1]`. Input vectors.
        wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
            Weight matrix shape: `[num_layers, R, H1]`.
        wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
            Weight matrix shape: `[num_layers, R, H2]`.
        s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices.
        s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices.
        layer_idx: Layer index of the weight matrices.
    """
    if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
        # Custom SGMV shrink only supports rank 16, 32, 64, 128
        _add_lora_sgmv_cutlass_legacy(
            y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank
        )
        return

    tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
    tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
    tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)


def _add_lora_sgmv_cutlass_legacy(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_ptr: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
    lora_rank: int,
):
    tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
    tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)


@lru_cache(maxsize=1)
def get_tmp_tensor(device: torch.device) -> torch.Tensor:
    return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)


@lru_cache(maxsize=32)
def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
    tmp_size = _kernels.sgmv_cutlass_tmp_size(size)
    return torch.empty((tmp_size,), dtype=torch.uint8, device=device)


def get_tmp_tensor_for_size_no_kernels(size: int, device: torch.device) -> torch.Tensor:
    return torch.empty((size,), dtype=torch.uint8, device=device)


def get_tmp_expand_size(size: int) -> int:
    return _kernels.sgmv_cutlass_tmp_size(size)


def get_tmp_tensors(
    nsegments: int, lora_rank: int, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
    if use_cutlass_shrink(lora_rank) and has_sgmv():
        tmp = get_tmp_tensor_for_size(nsegments, device)
        return tmp, tmp
    else:
        tmp_shrink = get_tmp_tensor(device)
        tmp_expand = get_tmp_tensor_for_size_no_kernels(nsegments, device)
        return tmp_shrink, tmp_expand


def lora_a_sgmv_cutlass(
    x: torch.Tensor,
    tmp: torch.Tensor,
    wa_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
    lora_rank: int,
) -> torch.Tensor:
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
        _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    else:
        _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    return v


def lora_b_sgmv_cutlass(
    y: torch.Tensor,
    v: torch.Tensor,
    tmp: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
):
    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)


"""
Semantics:
    y[i] += (
        x[i].unsqueeze(0)
        @ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
        @ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
        * scale
    ).squeeze(0)

Args:
    y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
    v: Shape: `[B, R]`. Temporary vector.
    x: Shape: `[B, H1]`. Input vectors.
    wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices.
    wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices.
    indicies: Shape: `[B]`. Indices of the LoRA weights.
    layer_idx: Layer index of LoRA weights.
    scale: Scaling factor.
"""


def add_lora_a_bgmv(
    v: torch.Tensor,
    x: torch.Tensor,
    wa_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
):
    _kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)


def add_lora_b_bgmv(
    y: torch.Tensor,
    v: torch.Tensor,
    wb_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
):
    _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)


def segmented_matmul(
    y: torch.Tensor,
    x: torch.Tensor,
    w: List[torch.Tensor],
    b: List[torch.Tensor],
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
):
    for i in range(len(w)):
        if s_end[i] - s_start[i] <= 0:
            continue

        xi = x[s_start[i] : s_end[i]]
        wi = w[i]
        bi = b[i]
        y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi)
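To illustrate the kernel-free fallback path, a minimal sketch of the segmented_matmul loop above on toy tensors (all shapes and values are made up); each row segment gets its own dense weight and bias, applied with F.linear.

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
y = torch.zeros(4, 6)
w = [torch.randn(6, 8), torch.randn(6, 8)]          # one weight per segment
b = [torch.zeros(6), torch.zeros(6)]
s_start = torch.tensor([0, 2], dtype=torch.int32)   # rows 0-1 -> w[0], rows 2-3 -> w[1]
s_end = torch.tensor([2, 4], dtype=torch.int32)

for i in range(len(w)):
    if s_end[i] - s_start[i] <= 0:
        continue
    y[s_start[i] : s_end[i]] = F.linear(x[s_start[i] : s_end[i]], w[i], b[i])

print(y.shape)  # torch.Size([4, 6])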