diff --git a/Dockerfile b/Dockerfile index d5750ac8..58df06e9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -145,6 +145,13 @@ COPY server/marlin/ . # Build specific version of transformers RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build +# Build Lorax Punica kernels +FROM kernel-builder as lorax-punica-builder +WORKDIR /usr/src +COPY server/Makefile-lorax-punica Makefile +# Build specific version of transformers +RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica + # Build Transformers CUDA kernels FROM kernel-builder as custom-kernels-builder WORKDIR /usr/src @@ -215,6 +222,7 @@ COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86 COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy build artifacts from marlin kernels builder COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy builds artifacts from vllm builder COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages @@ -266,4 +274,4 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh RUN chmod +x /tgi-entrypoint.sh ENTRYPOINT ["/tgi-entrypoint.sh"] -CMD ["--json-output"] +# CMD ["--json-output"] diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index b82d23ba..5e739703 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -157,6 +157,7 @@ async fn prefill( top_n_tokens: top_n_tokens.unwrap_or(0), blocks: vec![], slots: vec![], + adapter_id: None, }) .collect(); diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 7599562a..c9b4efd9 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -60,6 +60,9 @@ - local: conceptual/speculation title: Speculation (Medusa, ngram) - local: conceptual/guidance - title: How Guidance Works (via outlines) + title: How Guidance Works (via outlines + - local: conceptual/lora + title: LoRA (Low-Rank Adaptation) + title: Conceptual Guides diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index f6175925..5e40146f 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -416,6 +416,14 @@ Options: [env: MAX_CLIENT_BATCH_SIZE=] [default: 4] +``` +## LORA_ADAPTERS +```shell + --lora-adapters + Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during startup that will be available to callers via the `adapter_id` field in a request + + [env: LORA_ADAPTERS=] + ``` ## HELP ```shell diff --git a/docs/source/conceptual/lora.md b/docs/source/conceptual/lora.md new file mode 100644 index 00000000..08df767c --- /dev/null +++ b/docs/source/conceptual/lora.md @@ -0,0 +1,65 @@ +# LoRA (Low-Rank Adaptation) + +## What is LoRA? + +LoRA is a technique that allows for efficent fine-tuning a model while only updating a small portion of the model's weights. This is useful when you have a large model that has been pre-trained on a large dataset, but you want to fine-tune it on a smaller dataset or for a specific task. + +LoRA works by adding a small number of additional weights to the model, which are used to adapt the model to the new dataset or task. These additional weights are learned during the fine-tuning process, while the rest of the model's weights are kept fixed. + +## How is it used? + +LoRA can be used in many ways and the community is always finding new ways to use it. Here are some examples of how you can use LoRA: + +Technically, LoRA can be used to fine-tune a large language model on a small dataset. However, these use cases can span a wide range of applications, such as: + +- fine-tuning a language model on a small dataset +- fine-tuning a language model on a domain-specific dataset +- fine-tuning a language model on a dataset with limited labels + +## Optimizing Inference with LoRA + +LoRA's can be used during inference by mutliplying the adapter weights with the model weights at each specified layer. This process can be computationally expensive, but due to awesome work by [punica-ai](https://github.com/punica-ai/punica) and the [lorax](https://github.com/predibase/lorax) team, optimized kernels/and frameworks have been developed to make this process more efficient. TGI leverages these optimizations in order to provide fast and efficient inference with mulitple LoRA models. + +## Serving multiple LoRA adapters with TGI + +Once a LoRA model has been trained, it can be used to generate text or perform other tasks just like a regular language model. However, because the model has been fine-tuned on a specific dataset, it may perform better on that dataset than a model that has not been fine-tuned. + +In practice its often useful to have multiple LoRA models, each fine-tuned on a different dataset or for a different task. This allows you to use the model that is best suited for a particular task or dataset. + +Text Generation Inference (TGI) now supports loading multiple LoRA models at startup that can be used in generation requests. This feature is available starting from version `~2.0.6` and is compatible with LoRA models trained using the `peft` library. + +### Specifying LoRA models + +To use LoRA in TGI, when starting the server, you can specify the list of LoRA models to load using the `LORA_ADAPTERS` environment variable. For example: + +```bash +LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia +``` + +In the server logs, you will see the following message: + +```txt +Loading adapter weights into model: predibase/customer_support +Loading adapter weights into model: predibase/dbpedia +``` + +## Generate text + +You can then use these models in generation requests by specifying the `lora_model` parameter in the request payload. For example: + +```json +curl 127.0.0.1:3000/generate \ + -X POST \ + -H 'Content-Type: application/json' \ + -d '{ + "inputs": "Hello who are you?", + "parameters": { + "max_new_tokens": 40, + "adapter_id": "predibase/customer_support" + } +}' +``` + +> **Note:** The Lora feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally documentation and an improved client library will be published soon. + +An updated tutorial with detailed examples will be published soon. Stay tuned! diff --git a/launcher/src/main.rs b/launcher/src/main.rs index c7529604..816fa5f3 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -452,6 +452,11 @@ struct Args { /// Control the maximum number of inputs that a client can send in a single request #[clap(default_value = "4", long, env)] max_client_batch_size: usize, + + /// Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during + /// startup that will be available to callers via the `adapter_id` field in a request. + #[clap(long, env)] + lora_adapters: Option, } #[derive(Debug)] @@ -485,6 +490,7 @@ fn shard_manager( max_total_tokens: usize, max_batch_size: Option, max_input_tokens: usize, + lora_adapters: Option, otlp_endpoint: Option, otlp_service_name: String, log_level: LevelFilter, @@ -620,6 +626,11 @@ fn shard_manager( envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into())); } + // Lora Adapters + if let Some(lora_adapters) = lora_adapters { + envs.push(("LORA_ADAPTERS".into(), lora_adapters.into())); + } + // If huggingface_hub_cache is some, pass it to the shard // Useful when running inside a docker container if let Some(huggingface_hub_cache) = huggingface_hub_cache { @@ -1060,6 +1071,7 @@ fn spawn_shards( let rope_scaling = args.rope_scaling; let rope_factor = args.rope_factor; let max_batch_size = args.max_batch_size; + let lora_adapters = args.lora_adapters.clone(); thread::spawn(move || { shard_manager( model_id, @@ -1085,6 +1097,7 @@ fn spawn_shards( max_total_tokens, max_batch_size, max_input_tokens, + lora_adapters, otlp_endpoint, otlp_service_name, max_log_level, diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 01cc43fd..926c878e 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -134,6 +134,8 @@ message Request { repeated uint32 blocks = 9; /// Paged attention slots repeated uint32 slots = 10; + /// LORA adapter index + optional string adapter_id = 11; } message Batch { diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs index 9a3892fb..a996b14f 100644 --- a/router/client/src/v3/client.rs +++ b/router/client/src/v3/client.rs @@ -177,6 +177,7 @@ impl Client { }), prefill_logprobs: true, top_n_tokens: 20, + adapter_id: None, }); n_tokens += max_input_length; diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs index 94002f55..ae8a899b 100644 --- a/router/client/src/v3/sharded_client.rs +++ b/router/client/src/v3/sharded_client.rs @@ -244,6 +244,7 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), + adapter_id: None, }; let batch = Batch { id: u64::MAX, diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs index 3725c03e..93cf9469 100644 --- a/router/src/infer/v2/queue.rs +++ b/router/src/infer/v2/queue.rs @@ -429,6 +429,7 @@ mod tests { stop_sequences: vec![], }, top_n_tokens: 0, + adapter_id: None, }, response_tx, span: info_span!("entry"), diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index 0b66142a..ba65b9b6 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -351,6 +351,7 @@ impl State { top_n_tokens: entry.request.top_n_tokens, blocks, slots, + adapter_id: entry.request.adapter_id.clone(), }); // Set batch_time entry.batch_time = Some(Instant::now()); @@ -491,6 +492,7 @@ mod tests { stop_sequences: vec![], }, top_n_tokens: 0, + adapter_id: None, }, response_tx, span: info_span!("entry"), diff --git a/router/src/lib.rs b/router/src/lib.rs index 5d201937..126726c6 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -302,6 +302,11 @@ pub(crate) struct GenerateParameters { #[serde(default)] #[schema(nullable = true, default = "null", example = "null")] pub grammar: Option, + + /// Lora adapter id + #[serde(default)] + #[schema(nullable = true, default = "null", example = "null")] + pub adapter_id: Option, } fn default_max_new_tokens() -> Option { @@ -328,6 +333,7 @@ fn default_parameters() -> GenerateParameters { seed: None, top_n_tokens: None, grammar: None, + adapter_id: None, } } diff --git a/router/src/server.rs b/router/src/server.rs index aa872df9..6e6b93b6 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -673,6 +673,7 @@ async fn completions( seed, top_n_tokens: None, grammar: None, + ..Default::default() }, }) .collect(); @@ -1115,6 +1116,7 @@ async fn chat_completions( seed, top_n_tokens: req.top_logprobs, grammar, + ..Default::default() }, }; diff --git a/router/src/validation.rs b/router/src/validation.rs index bb9ad318..e2bf5a5d 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -202,6 +202,7 @@ impl Validation { decoder_input_details, top_n_tokens, grammar, + adapter_id, .. } = request.parameters; @@ -383,6 +384,7 @@ impl Validation { parameters, stopping_parameters, top_n_tokens, + adapter_id, }) } @@ -678,6 +680,7 @@ pub(crate) struct ValidGenerateRequest { pub parameters: ValidParameters, pub stopping_parameters: ValidStoppingParameters, pub top_n_tokens: u32, + pub adapter_id: Option, } #[derive(Error, Debug)] diff --git a/server/Makefile b/server/Makefile index 5257b876..0099c56a 100644 --- a/server/Makefile +++ b/server/Makefile @@ -4,6 +4,7 @@ include Makefile-vllm include Makefile-awq include Makefile-eetq include Makefile-selective-scan +include Makefile-lorax-punica unit-tests: pytest -s -vv -m "not private" tests diff --git a/server/Makefile-lorax-punica b/server/Makefile-lorax-punica new file mode 100644 index 00000000..72f06f76 --- /dev/null +++ b/server/Makefile-lorax-punica @@ -0,0 +1,12 @@ +lorax_punica_commit := c71861a653412267dc27ec86013dd945ce3474bc + +build-lorax-punica: + if [ ! -d 'lorax-punica' ]; then \ + git clone --no-checkout https://github.com/predibase/lorax.git lorax-punica; \ + fi + cd lorax-punica && git sparse-checkout set server/punica_kernels && git checkout $(lorax_punica_commit) + cd lorax-punica && git submodule update --init --recursive + cd lorax-punica/server/punica_kernels && python setup.py build + +install-lorax-punica: build-lorax-punica + cd lorax-punica/server/punica_kernels && python setup.py install diff --git a/server/tests/models/test_model.py b/server/tests/models/test_model.py index 32bcd45f..8441e8c6 100644 --- a/server/tests/models/test_model.py +++ b/server/tests/models/test_model.py @@ -17,7 +17,12 @@ def get_test_model(): tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b") model = TestModel( - torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu") + "test_model_id", + torch.nn.Linear(1, 1), + tokenizer, + False, + torch.float32, + torch.device("cpu"), ) return model diff --git a/server/text_generation_server/adapters/__init__.py b/server/text_generation_server/adapters/__init__.py new file mode 100644 index 00000000..8697cb9e --- /dev/null +++ b/server/text_generation_server/adapters/__init__.py @@ -0,0 +1,13 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/adapters/__init__.py +# License: Apache License Version 2.0, January 2004 + +from text_generation_server.adapters.weights import ( + AdapterBatchData, + AdapterBatchMetadata, +) + +__all__ = [ + "AdapterBatchData", + "AdapterBatchMetadata", +] diff --git a/server/text_generation_server/adapters/config.py b/server/text_generation_server/adapters/config.py new file mode 100644 index 00000000..5261d4b5 --- /dev/null +++ b/server/text_generation_server/adapters/config.py @@ -0,0 +1,44 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/adapters/config.py +# License: Apache License Version 2.0, January 2004 + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple + +import torch + +from text_generation_server.adapters.weights import AdapterWeights + +if TYPE_CHECKING: + from text_generation_server.models.model import Model + + +@dataclass +class ModuleMap: + module_name: str + module_weights: Dict[str, Tuple[torch.Tensor, str]] + + +@dataclass +class AdapterConfig(ABC): + base_model_name_or_path: str + + @abstractmethod + def map_weights_for_model( + self, + adapter_weights: Dict[int, AdapterWeights], + weight_names: Tuple[str], + ) -> Tuple[ModuleMap, Set[str]]: + pass + + @abstractmethod + def load_batched_adapter_weights( + self, + model: "Model", + module_map: ModuleMap, + layer_type: str, + unused_weight_names: Set[str], + dynamic: bool, + ) -> Optional[AdapterWeights]: + pass diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py new file mode 100644 index 00000000..87543be2 --- /dev/null +++ b/server/text_generation_server/adapters/lora.py @@ -0,0 +1,482 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/adapters/lora.py +# License: Apache License Version 2.0, January 2004 + +from collections import defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union + +import torch +from peft import LoraConfig as _LoraConfig +from torch.distributed import ProcessGroup + +from text_generation_server.adapters.config import AdapterConfig, ModuleMap + +from text_generation_server.adapters.weights import ( + AdapterBatchMetadata, + AdapterWeights, + BatchAdapterWeights, +) +from text_generation_server.utils.sgmv import ( + BGMV_MAX_RANK, + MAX_RANK_CUSTOM, + get_tmp_tensors, + orient_for_rank, + pad_rank, + use_cutlass_shrink, +) + +if TYPE_CHECKING: + from text_generation_server.models.model import Model + + +def get_start_stop_idxs_for_rank(offset, size, rank, world_size): + block_size = size // world_size + start = offset + rank * block_size + stop = offset + (rank + 1) * block_size + return start, stop + + +def shard_on_dim( + t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup +): + world_size = process_group.size() + rank = process_group.rank() + + size = t.shape[dim] + start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size) + + if dim == 0: + tensor = t[start:stop] + elif dim == 1: + tensor = t[:, start:stop] + else: + raise NotImplementedError("Let's make that generic when needed") + + return tensor + + +def shard_lora_weights( + weights_a: List[torch.Tensor], + weights_b: List[torch.Tensor], + split_dim: int, + process_group: ProcessGroup, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + # [hidden_size, r] + weights_a = [ + shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a + ] + + # [r, hidden_size] + weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b] + + return weights_a, weights_b + + +@dataclass +class LoraConfig(AdapterConfig): + r: int + target_modules: Optional[Union[List[str], str]] + fan_in_fan_out: bool + lora_alpha: int + use_rslora: bool + + def map_weights_for_model( + self, + adapter_weights: Dict[int, AdapterWeights], + weight_names: Tuple[str], + ) -> Tuple[ModuleMap, Set[str]]: + adapter_weight_names = set() + module_map = {} + for weight_name in weight_names: + lora_a_name = f"base_model.model.{weight_name}.lora_A.weight" + lora_b_name = f"base_model.model.{weight_name}.lora_B.weight" + if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights: + continue + + module_map[weight_name] = { + "lora_A": (adapter_weights[lora_a_name], lora_a_name), + "lora_B": (adapter_weights[lora_b_name], lora_b_name), + } + adapter_weight_names.add(lora_a_name) + adapter_weight_names.add(lora_b_name) + return module_map, adapter_weight_names + + def load_batched_adapter_weights( + self, + model: "Model", + module_map: Dict[str, Dict], + layer_type: str, + unused_weight_names: Set[str], + dynamic: bool, + ) -> Optional[AdapterWeights]: + return LoraWeights.load( + self, + model, + module_map, + layer_type, + unused_weight_names, + ) + + @classmethod + def load(cls, adapter_id: str, api_token: str) -> "LoraConfig": + hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token) + return cls( + base_model_name_or_path=hf_config.base_model_name_or_path, + r=hf_config.r, + target_modules=hf_config.target_modules, + fan_in_fan_out=hf_config.fan_in_fan_out, + lora_alpha=hf_config.lora_alpha, + use_rslora=( + hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False + ), + ) + + +class LoraWeights(AdapterWeights): + """LoRA weights for a single adapter merged across all layers.""" + + def __init__( + self, + weights_a: List[torch.Tensor], + weights_b: List[torch.Tensor], + adapter_config: LoraConfig, + ): + self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1 + self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1 + + self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r) + self._is_transposed = False + + # [num_layers, hidden_size, r] + weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a] + self._weights_a = torch.stack(weights_a) + + # [num_layers, r, hidden_size] + self._weights_b = torch.stack(weights_b) + + self.adapter_config = adapter_config + + @property + def weights_a(self) -> torch.Tensor: + if self._is_transposed: + self._transpose_weights() + return self._weights_a + + @property + def weights_b(self) -> torch.Tensor: + if self._is_transposed: + self._transpose_weights() + return self._weights_b + + @property + def weights_a_t(self) -> torch.Tensor: + if not self._is_transposed: + self._transpose_weights() + return self._weights_a + + @property + def weights_b_t(self) -> torch.Tensor: + if not self._is_transposed: + self._transpose_weights() + return self._weights_b + + def _transpose_weights(self): + if self._use_cutlass_shrink: + # If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation + self._weights_a = self._weights_a.transpose(1, 2).contiguous() + self._weights_b = self._weights_b.transpose(1, 2).contiguous() + self._is_transposed = not self._is_transposed + + @classmethod + def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]: + return [BatchLoraWeights] + + @classmethod + def load( + cls, + config: LoraConfig, + model: "Model", + module_map: Dict[str, Dict], + layer_type: str, + unused_weight_names: Set[str], + ) -> Optional[AdapterWeights]: + nlayers = model.get_num_layers_for_type(layer_type) + lora_a_list = [None] * nlayers + lora_b_list = [None] * nlayers + + for layer_id in range(nlayers): + key = (layer_id, layer_type) + weight_name, layer = model.target_to_layer[key] + base_weight = layer.base_layer.linear.weight + base_device = base_weight.device + + if weight_name not in module_map: + # There is no LoRA weight for this layer type in the adapter + return None + + lora_a, lora_a_name = module_map[weight_name]["lora_A"] + lora_a = lora_a.to(base_device, model.dtype) + + lora_b, lora_b_name = module_map[weight_name]["lora_B"] + lora_b = lora_b.to(base_device, model.dtype) + + scale = get_scaling_factor( + config.lora_alpha, + config.r, + uses_rslora=config.use_rslora, + ) + + unused_weight_names.discard(lora_a_name) + unused_weight_names.discard(lora_b_name) + + # Merge scaling factor into lora_b due to associativity of matrix multiplication: + # (A * B) * C = A * (B * C) + lora_a_list[layer_id] = lora_a.transpose(0, 1) + lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale + + # pad lora ranks to be compatible with sgmv + lora_a_list = [ + pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list + ] + lora_b_list = [ + pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list + ] + + if lora_a_list: + # update rank if it was padded + padded_rank = lora_a_list[0].size(1) + config.r = padded_rank + + return LoraWeights( + *shard_lora_weights( + weights_a=lora_a_list, + weights_b=lora_b_list, + split_dim=0 if model.is_row_parallel(layer_type) else 1, + process_group=model.process_group, + ), + config, + ) + + +@dataclass +class RankSegments: + rank: int + + lora_a_ptr: torch.Tensor + lora_b_ptr: torch.Tensor + + # prefill (sgmv) + tmp_shrink: torch.Tensor + tmp_expand: torch.Tensor + segment_starts: torch.Tensor + segment_ends: torch.Tensor + + # decode (bgmv) + indices: torch.Tensor + + +@dataclass +class BatchLoraWeights(BatchAdapterWeights): + lora_a: Dict[int, torch.Tensor] + lora_b: Dict[int, torch.Tensor] + adapter_index_configs: Dict[int, LoraConfig] + rank_data: Dict[int, RankSegments] + use_sgmv: bool + + def has_adapter(self, adapter_index: int) -> bool: + return adapter_index in self.adapter_index_configs + + def can_vectorize(self, pg: ProcessGroup) -> bool: + return all( + rank_data.rank // pg.size() <= MAX_RANK_CUSTOM + for rank_data in self.rank_data.values() + ) + + @classmethod + def key(cls) -> str: + return "lora" + + @classmethod + def load( + self, + adapter_weights: Dict[int, AdapterWeights], + meta: AdapterBatchMetadata, + prefill: bool, + prefill_head_indices: Optional[torch.Tensor], + ) -> Optional["BatchLoraWeights"]: + adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()} + adapter_weights = { + k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights) + } + if not adapter_weights: + return None + + first_weights = next(iter(adapter_weights.values())) + device = first_weights.weights_a.device + segment_indices = meta.segment_indices + + lora_a = { + idx: adapter_weights[idx].weights_a + for idx in segment_indices + if idx in adapter_weights + } + lora_b = { + idx: adapter_weights[idx].weights_b + for idx in segment_indices + if idx in adapter_weights + } + + max_rank = max( + ( + adapter_weights[idx].lora_a_r + for idx in segment_indices + if idx in adapter_weights + ), + default=0, + ) + + if prefill or max_rank > BGMV_MAX_RANK: + use_sgmv = True + lora_a_ptr = torch.tensor( + [ + ( + adapter_weights[idx].weights_a.data_ptr() + if idx in adapter_weights + else 0 + ) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + lora_b_ptr = torch.tensor( + [ + ( + adapter_weights[idx].weights_b.data_ptr() + if idx in adapter_weights + else 0 + ) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + else: + use_sgmv = False + lora_a_ptr = torch.tensor( + [ + ( + adapter_weights[idx].weights_a_t.data_ptr() + if idx in adapter_weights + else 0 + ) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + lora_b_ptr = torch.tensor( + [ + ( + adapter_weights[idx].weights_b_t.data_ptr() + if idx in adapter_weights + else 0 + ) + for idx in segment_indices + ], + dtype=torch.int64, + device=device, + ) + + adapter_index_configs = { + idx: adapter_weights[idx].adapter_config + for idx in segment_indices + if idx in adapter_weights + } + + adapter_to_segment = {v: k for k, v in enumerate(segment_indices)} + + rank_indices = defaultdict(list) + for segment_idx, adapter_idx in enumerate(segment_indices): + if adapter_idx not in adapter_weights: + continue + rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx) + + if prefill_head_indices is not None: + j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0] + for head_index in prefill_head_indices: + # j cannot go out of bounds as that would mean there are tokens without corresponding adapters + if head_index < meta.adapter_segments[j]: + prefill_head_segment_ends[-1] += 1 + else: + prefill_head_segment_starts.append(prefill_head_segment_ends[-1]) + prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1) + j += 1 + + rank_data = {} + for rank, indices in rank_indices.items(): + tmp_shrink = None + tmp_expand = None + segment_starts = None + segment_ends = None + batch_indices = None + + if use_sgmv: + lora_a_ptr_indices = lora_a_ptr[indices] + tmp_shrink, tmp_expand = get_tmp_tensors( + lora_a_ptr_indices.size(0), rank, device + ) + segment_starts = meta.adapter_segments[indices] + segment_ends = meta.adapter_segments[[i + 1 for i in indices]] + if prefill_head_indices is not None: + for i, segment_index in enumerate(indices): + segment_starts[i] = prefill_head_segment_starts[segment_index] + segment_ends[i] = prefill_head_segment_ends[segment_index] + else: + rank_indices = set(indices) + batch_indices = [ + adapter_to_segment[idx] for idx in meta.adapter_indices.tolist() + ] + batch_indices = [ + idx if idx in rank_indices else -1 for idx in batch_indices + ] + batch_indices = torch.tensor( + batch_indices, dtype=torch.int64, device=device + ) + + rank_data[rank] = RankSegments( + rank=rank, + tmp_shrink=tmp_shrink, + tmp_expand=tmp_expand, + lora_a_ptr=lora_a_ptr[indices], + lora_b_ptr=lora_b_ptr[indices], + segment_starts=segment_starts, + segment_ends=segment_ends, + indices=batch_indices, + ) + + return BatchLoraWeights( + lora_a=lora_a, + lora_b=lora_b, + adapter_index_configs=adapter_index_configs, + rank_data=rank_data, + use_sgmv=use_sgmv, + ) + + +def get_scaling_factor( + lora_alpha: int, + r: int, + uses_rslora: bool = False, +) -> float: + """Computes the scaling factor for the lora weights.""" + if uses_rslora: + return lora_alpha / (r**0.5) + return lora_alpha / r + + +def _convert_lora(v: AdapterWeights) -> AdapterWeights: + if hasattr(v, "lora_weights"): + return v.lora_weights + return v diff --git a/server/text_generation_server/adapters/weights.py b/server/text_generation_server/adapters/weights.py new file mode 100644 index 00000000..8f658756 --- /dev/null +++ b/server/text_generation_server/adapters/weights.py @@ -0,0 +1,158 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/adapters/weights.py +# License: Apache License Version 2.0, January 2004 + +from abc import ABC, abstractclassmethod +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, List, Optional, Set, Type + +import torch + + +@dataclass +class AdapterBatchMetadata: + # [batch_size] + adapter_indices: torch.Tensor + + # [num_adapters] + adapter_set: Set[int] + + # [num_segments + 1] + adapter_segments: torch.Tensor + + # [num_segments] + # maps from segment index to adapter index, i.e.: + # segment_indices[s] == adapter_indices[i] + segment_indices: List[int] + + +class AdapterWeights(ABC): + @abstractclassmethod + def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]: + pass + + @property + def speculative_tokens(self) -> int: + return 0 + + +class BatchAdapterWeights(ABC): + @abstractclassmethod + def has_adapter(self, adapter_index: int) -> bool: + pass + + @abstractclassmethod + def key(cls) -> str: + pass + + @abstractclassmethod + def load( + cls, + adapter_weights: Dict[int, AdapterWeights], + meta: "AdapterBatchMetadata", + prefill: bool, + prefill_head_indices: torch.Tensor, + ) -> Optional["BatchAdapterWeights"]: + pass + + +class LayerAdapterWeights: + """Adapter weights that apply to a particular layer.""" + + def __init__(self): + self.adapter_weights: Dict[int, AdapterWeights] = {} + + def add_adapter(self, adapter_idx: int, weights: AdapterWeights): + self.adapter_weights[adapter_idx] = weights + + def remove_adapter(self, adapter_idx: int): + if adapter_idx not in self.adapter_weights: + return + del self.adapter_weights[adapter_idx] + + @property + def max_speculative_tokens(self) -> int: + return max( + adapter_weights.speculative_tokens + for adapter_weights in self.adapter_weights.values() + ) + + def is_empty(self) -> bool: + return len(self.adapter_weights) == 0 + + def get_data( + self, + meta: AdapterBatchMetadata, + prefill: bool, + prefill_head_indices: Optional[torch.Tensor], + ) -> Dict[str, BatchAdapterWeights]: + # bucket adapters by batch class + adapter_batch_types: Dict[ + Type[BatchAdapterWeights], Dict[int, AdapterWeights] + ] = defaultdict(dict) + for adapter_index, adapter_weights in self.adapter_weights.items(): + for batch_type in adapter_weights.get_batch_types(): + adapter_batch_types[batch_type][adapter_index] = adapter_weights + + batch_data = {} + for batch_type, adapter_weights in adapter_batch_types.items(): + batched_weights = batch_type.load( + adapter_weights, meta, prefill, prefill_head_indices + ) + if batched_weights is not None: + batch_data[batch_type.key()] = batched_weights + return batch_data + + +@dataclass +class AdapterBatchData: + meta: AdapterBatchMetadata + + # layer type -> adapter type -> batch weight data + data: Dict[str, Dict[str, BatchAdapterWeights]] + + prefill: bool + + @staticmethod + def from_meta( + meta: AdapterBatchMetadata, + weights: Dict[str, LayerAdapterWeights], + prefill: bool, + prefill_head_indices: Optional[torch.Tensor], + ) -> "AdapterBatchData": + data = {} + for k, v in weights.items(): + if v.is_empty(): + continue + data[k] = v.get_data( + meta, prefill, prefill_head_indices if k == "lm_head" else None + ) + return AdapterBatchData(meta=meta, data=data, prefill=prefill) + + def ranks(self) -> Set[int]: + # TODO(travis): refactor to be less coupled to lora implementation + ranks = set() + for layer_data in self.data.values(): + lora_data = layer_data.get("lora") + if lora_data is None: + continue + + for rank_data in lora_data.rank_data.values(): + ranks.add(rank_data.rank) + + return ranks + + def layer_names(self) -> Set[str]: + return set(self.data.keys()) + + def adapter_keys(self) -> Set[str]: + adapter_keys = set() + for layer_data in self.data.values(): + adapter_keys.update(layer_data.keys()) + return adapter_keys + + @property + def max_rank(self) -> int: + ranks = self.ranks() + return max(ranks) if len(ranks) > 0 else 0 diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 18cad071..68ae95dd 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -79,6 +79,18 @@ def serve( if otlp_endpoint is not None: setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint) + lora_adapter_ids = os.getenv("LORA_ADAPTERS", None) + + # split on comma and strip whitespace + lora_adapter_ids = ( + [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else [] + ) + + if len(lora_adapter_ids) > 0: + logger.warning( + f"LoRA adapters are enabled. This is an experimental feature and may not work as expected." + ) + # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value dtype = None if dtype is None else dtype.value @@ -93,6 +105,7 @@ def serve( ) server.serve( model_id, + lora_adapter_ids, revision, sharded, quantize, @@ -113,6 +126,7 @@ def download_weights( logger_level: str = "INFO", json_output: bool = False, trust_remote_code: bool = False, + merge_lora: bool = False, ): # Remove default handler logger.remove() @@ -143,18 +157,28 @@ def download_weights( ) is not None if not is_local_model: - try: - adapter_config_filename = hf_hub_download( - model_id, revision=revision, filename="adapter_config.json" - ) - utils.download_and_unload_peft( - model_id, revision, trust_remote_code=trust_remote_code - ) - is_local_model = True - utils.weight_files(model_id, revision, extension) - return - except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): - pass + # TODO: maybe reverse the default value of merge_lora? + # currently by default we don't merge the weights with the base model + if merge_lora: + try: + adapter_config_filename = hf_hub_download( + model_id, revision=revision, filename="adapter_config.json" + ) + utils.download_and_unload_peft( + model_id, revision, trust_remote_code=trust_remote_code + ) + is_local_model = True + utils.weight_files(model_id, revision, extension) + return + except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): + pass + else: + try: + utils.peft.download_peft( + model_id, revision, trust_remote_code=trust_remote_code + ) + except Exception: + pass try: import json diff --git a/server/text_generation_server/layers/__init__.py b/server/text_generation_server/layers/__init__.py index c29dd092..32c8d121 100644 --- a/server/text_generation_server/layers/__init__.py +++ b/server/text_generation_server/layers/__init__.py @@ -12,3 +12,9 @@ from text_generation_server.layers.speculative import SpeculativeHead # Just to add the `load` methods. from text_generation_server.layers.layernorm import load_layer_norm from text_generation_server.layers.conv import load_conv2d + +from text_generation_server.layers.lora import ( + LoraLinear, + TensorParallelMultiAdapterLinear, + TensorParallelAdapterRowLinear, +) diff --git a/server/text_generation_server/layers/lora.py b/server/text_generation_server/layers/lora.py new file mode 100644 index 00000000..0bb6db41 --- /dev/null +++ b/server/text_generation_server/layers/lora.py @@ -0,0 +1,286 @@ +import math +import os +from typing import TYPE_CHECKING, Optional, Tuple, List + +import torch +import torch.distributed +from accelerate import init_empty_weights +from torch import nn +from torch.nn import functional as F +from torch.distributed import ProcessGroup + +from text_generation_server.utils.sgmv import ( + add_lora_a_bgmv, + add_lora_b_bgmv, + has_sgmv, + lora_a_sgmv_cutlass, + lora_b_sgmv_cutlass, + orient_for_rank, +) + +if TYPE_CHECKING: + from text_generation_server.adapters import AdapterBatchData + from text_generation_server.adapters.lora import BatchLoraWeights + + +class LoraLinear(nn.Module): + def __init__( + self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup + ): + super().__init__() + self.base_layer = base_layer + self.layer_id = layer_id + self.process_group = process_group + + def forward_layer_type( + self, + result: torch.Tensor, + input: torch.Tensor, + adapter_data: "AdapterBatchData", + layer_type: str, + start_idx: int, + end_idx: int, + ) -> torch.Tensor: + if adapter_data is None: + return result + data = adapter_data.data.get(layer_type) + data: Optional["BatchLoraWeights"] = ( + data.get("lora") if data is not None else None + ) + + if has_sgmv() and data is not None and data.can_vectorize(self.process_group): + # In tensor-parallel configurations, each GPU processes a specific segment of the output. + # The 'result' tensor represents the full output, which can vary in size based on + # the layer type (e.g., attention vs. feed-forward layers). We define the current + # segment using start_idx and end_idx. If the segment size doesn't match this GPU's + # slice of 'result', we create a zero tensor of the correct size for LoRA computation. + # This approach ensures accurate LoRA application across various layer sizes and + # configurations, adapting to different model architectures and parallelization strategies. + # + # Example scenarios where this is necessary: + # 1. The adapter's size doesn't evenly divide across GPUs. + # 2. We're processing the last segment which might be smaller. + # 3. Different projection layers (q, k, v) have different sizes. + if end_idx - start_idx != result.shape[1]: + proj = torch.zeros_like(result[:, start_idx:end_idx]) + else: + proj = result + + for r, rank_segments in data.rank_data.items(): + lora_a_ptr = rank_segments.lora_a_ptr + lora_b_ptr = rank_segments.lora_b_ptr + + if lora_a_ptr is None or lora_b_ptr is None: + raise ValueError("LoRA data is missing") + + if data.use_sgmv: + # Use SGMV for prefill + v = lora_a_sgmv_cutlass( + input, + rank_segments.tmp_shrink, + lora_a_ptr, + rank_segments.segment_starts, + rank_segments.segment_ends, + self.layer_id, + r, + ) + + if self.process_group.size() > 1: + v = self.collect_lora_a(v) + + lora_b_sgmv_cutlass( + proj, + v, + rank_segments.tmp_expand, + lora_b_ptr, + rank_segments.segment_starts, + rank_segments.segment_ends, + self.layer_id, + ) + else: + # Use BGMV for decode + v = torch.zeros( + (input.size(0), r), dtype=input.dtype, device=input.device + ) + # TODO: error with [-1, 0], but not [0, -1] + add_lora_a_bgmv( + v, + input, + lora_a_ptr, + rank_segments.indices, + self.layer_id, + ) + + if self.process_group.size() > 1: + v = self.collect_lora_a(v) + + add_lora_b_bgmv( + proj, + v, + lora_b_ptr, + rank_segments.indices, + self.layer_id, + ) + + if end_idx - start_idx != result.shape[1]: + result[:, start_idx:end_idx] += proj + else: + for adapter_index in adapter_data.meta.adapter_set: + if data is not None and data.has_adapter(adapter_index): + adapter_mask = ( + (adapter_data.meta.adapter_indices == adapter_index) + .to(input.dtype) + .view(-1, 1) + ) + layer_result = self.forward_lora( + input, data, adapter_index, adapter_mask + ) + result[:, start_idx:end_idx] += layer_result + + return result + + def forward_lora( + self, + input: torch.Tensor, + data: "BatchLoraWeights", + adapter_index: int, + adapter_mask: torch.Tensor, + ) -> torch.Tensor: + lora_a = data.lora_a[adapter_index][self.layer_id, :, :] + lora_b = data.lora_b[adapter_index][self.layer_id, :, :] + + lora_a = orient_for_rank(lora_a, lora_b.size(0)) + + a_out = input @ lora_a + if self.process_group.size() > 1: + a_out = self.collect_lora_a(a_out) + + result = (a_out @ lora_b) * adapter_mask + return result + + def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("Implemented in subclasses") + + +class TensorParallelMultiAdapterLinear(LoraLinear): + def __init__( + self, + base_layer: nn.Module, + layer_id: int, + layer_names: List[str], + sizes: List[int], + process_group: ProcessGroup, + ): + super().__init__(base_layer, layer_id, process_group) + self.layer_names = layer_names + self.sizes = sizes + + @classmethod + def load( + cls, + base_layer: nn.Module, + layer_id: int, + layer_names: List[str], + sizes: List[int], + process_group: ProcessGroup, + ): + return TensorParallelMultiAdapterLinear( + base_layer, layer_id, layer_names, sizes, process_group + ) + + def forward( + self, input: torch.Tensor, adapter_data: "AdapterBatchData" + ) -> torch.Tensor: + result = self.base_layer(input) + + # noop if no layer names are provided (e.g. for models without adapters) + if self.layer_names is None: + return result + + # handle models like Bloom that have inputs of shape + # (batch_size, sequence_length, hidden_size) + # we need to reshape them to (batch_size * sequence_length, hidden_size) + # for the LoRA computation, then reshape back + prev_shape = result.shape + is_3d = len(input.shape) >= 3 + if is_3d: + input = input.reshape(-1, input.shape[-1]) + result = result.reshape(-1, result.shape[-1]) + + offset = 0 + for i, layer_name in enumerate(self.layer_names): + start_idx = offset // self.process_group.size() + # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple + # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It + # ensures correct slicing of the result tensor, accommodating variations like grouped-query + # attention where k_proj and v_proj differ from q_proj. This allows precise application of + # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the + # different projection sizes across layers and model architectures. + if self.sizes is not None: + offset += self.sizes[i] + end_idx = offset // self.process_group.size() + else: + end_idx = result.shape[1] + + result = self.forward_layer_type( + result, input, adapter_data, layer_name, start_idx, end_idx + ) + + if is_3d: + result = result.reshape(prev_shape) + + return result + + def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: + # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise. + # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks. + # + # TODO(travis): this is not very efficient as we do an all-gather for every adapter, + # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same + # rank, compute `a_out` on each, and then slice them into the buffer as shown here: + # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609 + gathered_tensors = [ + torch.empty_like(a_out) for _ in range(self.process_group.size()) + ] + torch.distributed.all_gather(gathered_tensors, a_out) + return torch.cat(gathered_tensors, dim=1) + + +class TensorParallelAdapterRowLinear(LoraLinear): + def __init__(self, base_layer, layer_id, layer_name, process_group): + super().__init__(base_layer, layer_id, process_group) + self.layer_name = layer_name + + @classmethod + def load(cls, base_layer, layer_id, layer_name, process_group): + return cls(base_layer, layer_id, layer_name, process_group) + + def forward( + self, input: torch.Tensor, adapter_data: "AdapterBatchData" + ) -> torch.Tensor: + result = self.base_layer(input) + + if self.layer_name is None: + return result + + # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285 + stride = result.shape[-1] // self.process_group.size() + start_idx = self.process_group.rank() * stride + end_idx = (self.process_group.rank() + 1) * stride + + self.forward_layer_type( + result, input, adapter_data, self.layer_name, start_idx, end_idx + ) + + return result + + def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor: + # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise. + # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks. + # + # TODO(travis): this is not very efficient as we do an all-reduce for every adapter, + # instead we could pre-allocate a (B, a, r) tensor for all adapters with the same + # rank, compute `a_out` on each, and then slice them into the buffer as shown here: + # https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609 + torch.distributed.all_reduce(a_out, group=self.process_group) + return a_out diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 76dca3dc..648fcee9 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -6,7 +6,7 @@ from loguru import logger from transformers.configuration_utils import PretrainedConfig from transformers.models.auto import modeling_auto from huggingface_hub import hf_hub_download, HfApi -from typing import Optional +from typing import Optional, List from pathlib import Path from text_generation_server.utils.speculate import get_speculate, set_speculate @@ -253,6 +253,7 @@ for data in ModelType: def get_model( model_id: str, + lora_adapter_ids: Optional[List[str]], revision: Optional[str], sharded: bool, quantize: Optional[str], @@ -595,6 +596,7 @@ def get_model( speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 38006502..17aa12e8 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -90,6 +90,7 @@ class BLOOMSharded(CausalLM): torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index e896c831..10c64c66 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -538,6 +538,7 @@ class CausalLM(Model): tokenizer.add_special_tokens({"pad_token": "[PAD]"}) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 6d315ba5..2850a6f3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -514,6 +514,7 @@ class FlashCohereForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( input_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index f81bfa10..9d56e4ef 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -724,6 +724,7 @@ class FlashDbrxForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( input_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 04d05cd6..a4fd4740 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -460,6 +460,7 @@ class FlashGemmaForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: input_embeds = self.embed_tokens(input_ids) hidden_states = self.model( diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 0c01f56a..7e7510c7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -445,6 +445,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: token_embeds = self.embed_tokens(input_ids) position_embeds = self.embed_positions(position_ids) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 0d06d104..c48ed268 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -38,6 +38,8 @@ from text_generation_server.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, SpeculativeHead, + TensorParallelMultiAdapterLinear, + TensorParallelAdapterRowLinear, ) from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( @@ -51,43 +53,61 @@ if SYSTEM == "rocm": raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") -def load_attention(config, prefix, weights): +def load_attention(config, prefix, weights, layer_id): # Only defined in granite. bias = getattr(config, "attention_bias", False) + head_size = config.hidden_size // config.num_attention_heads + sizes = None + prefixes = None - # if specific model type, load the correct attention if config.model_type == "phi3": - return TensorParallelColumnLinear.load_qkv( + prefix = f"{prefix}.qkv_proj" + base_layer = TensorParallelColumnLinear.load_qkv( config, - prefix=f"{prefix}.qkv_proj", + prefix=prefix, weights=weights, bias=bias, num_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, ) elif config.model_type == "baichuan": - return TensorParallelColumnLinear.load_qkv( + prefix = f"{prefix}.W_pack" + base_layer = TensorParallelColumnLinear.load_qkv( config, - prefix=f"{prefix}.W_pack", + prefix=prefix, weights=weights, bias=bias, num_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, ) + else: + prefixes = ["q_proj", "k_proj", "v_proj"] + sizes = [ + head_size * config.num_attention_heads, + head_size * config.num_key_value_heads, + head_size * config.num_key_value_heads, + ] + base_layer = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=bias, + ) - # otherwise, load the default attention based on the number of heads - return TensorParallelColumnLinear.load_multi( - config, - prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], - dim=0, - weights=weights, - bias=bias, + return TensorParallelMultiAdapterLinear.load( + base_layer=base_layer, + layer_id=layer_id, + layer_names=prefixes, + sizes=sizes, + process_group=weights.process_group, ) class FlashLlamaAttention(torch.nn.Module): def __init__( self, + index: int, prefix: str, config, weights, @@ -121,14 +141,23 @@ class FlashLlamaAttention(torch.nn.Module): config.num_key_value_heads // weights.process_group.size() ) - self.query_key_value = load_attention(config, prefix, weights) + self.query_key_value = load_attention(config, prefix, weights, index) + self.index = index - self.o_proj = TensorParallelRowLinear.load( + o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", weights=weights, bias=False, ) + + self.o_proj = TensorParallelAdapterRowLinear.load( + o_proj, + index, + "o_proj", + process_group=weights.process_group, + ) + self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -145,8 +174,9 @@ class FlashLlamaAttention(torch.nn.Module): slots, input_lengths, max_s, + adapter_data, ): - qkv = self.query_key_value(hidden_states) + qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( [ self.head_size * self.num_heads, @@ -190,11 +220,13 @@ class FlashLlamaAttention(torch.nn.Module): max_s, ) - return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) class LlamaMLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, weights, index): super().__init__() self.hidden_act = config.hidden_act self.act = ( @@ -209,29 +241,54 @@ class LlamaMLP(nn.Module): ), ) ) + prefixes = None + sizes = None + # Fuse gate and up proj bias = getattr(config, "mlp_bias", False) if config.model_type == "phi3": - self.gate_up_proj = TensorParallelColumnLinear.load_gate_up( + gate_up_proj = TensorParallelColumnLinear.load_gate_up( config, prefix=f"{prefix}.gate_up_proj", weights=weights, bias=bias, ) else: - self.gate_up_proj = TensorParallelColumnLinear.load_multi( + prefixes = [f"gate_proj", f"up_proj"] + sizes = [ + config.intermediate_size, + config.intermediate_size, + ] + gate_up_proj = TensorParallelColumnLinear.load_multi( config, prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], weights=weights, dim=0, bias=bias, ) - self.down_proj = TensorParallelRowLinear.load( + + self.gate_up_proj = TensorParallelMultiAdapterLinear.load( + gate_up_proj, + index, + layer_names=prefixes, + sizes=sizes, + process_group=weights.process_group, + ) + + down_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.down_proj", weights=weights, bias=bias, ) + + self.down_proj = TensorParallelAdapterRowLinear.load( + down_proj, + index, + "down_proj", + process_group=weights.process_group, + ) + self.intermediate_size = ( config.intermediate_size // weights.process_group.size() ) @@ -239,7 +296,7 @@ class LlamaMLP(nn.Module): # TODO: This is a hotfix to be removed & properly refactored. self.quantize = config.quantize - def forward(self, hidden_states): + def forward(self, hidden_states, adapter_data): if ( SYSTEM == "rocm" and self.hidden_act == "silu" @@ -253,20 +310,27 @@ class LlamaMLP(nn.Module): device="cuda", ) _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) - return self.down_proj(out) + return self.down_proj(out, adapter_data) else: - gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) class FlashLlamaLayer(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, index, prefix, config, weights): super().__init__() self.self_attn = FlashLlamaAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights + index=index, + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + ) + self.mlp = LlamaMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, index=index ) - self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) self.input_layernorm = FastRMSNorm.load( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps @@ -289,6 +353,7 @@ class FlashLlamaLayer(nn.Module): slots, input_lengths, max_s, + adapter_data, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -303,6 +368,7 @@ class FlashLlamaLayer(nn.Module): slots, input_lengths, max_s, + adapter_data, ) # faster post attention rms norm @@ -310,7 +376,7 @@ class FlashLlamaLayer(nn.Module): attn_output, res ) - mlp_output = self.mlp(normed_attn_res_output) + mlp_output = self.mlp(normed_attn_res_output, adapter_data) return mlp_output, attn_res @@ -325,6 +391,7 @@ class FlashLlamaModel(torch.nn.Module): self.layers = nn.ModuleList( [ FlashLlamaLayer( + index=layer_id, prefix=( f"model.layers.{layer_id}" if not prefix @@ -360,6 +427,7 @@ class FlashLlamaModel(torch.nn.Module): max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], + adapter_data, ) -> torch.Tensor: hidden_states = inputs_embeds @@ -382,6 +450,7 @@ class FlashLlamaModel(torch.nn.Module): slots, input_lengths, max_s, + adapter_data, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -423,6 +492,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( @@ -436,6 +506,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): max_s, true_max_s=max_s, prefill_cache_indices=prefill_cache_indices, + adapter_data=adapter_data, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 77a8a384..d1ba5564 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -38,6 +38,8 @@ from text_generation_server.layers import ( TensorParallelEmbedding, SpeculativeHead, get_linear, + TensorParallelMultiAdapterLinear, + TensorParallelAdapterRowLinear, ) from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( @@ -107,12 +109,7 @@ class MistralConfig(PretrainedConfig): class MistralAttention(torch.nn.Module): - def __init__( - self, - prefix: str, - config, - weights, - ): + def __init__(self, prefix: str, config, weights, layer_id): super().__init__() self.max_past = ( config.sliding_window if config.sliding_window is not None else -1 @@ -140,7 +137,7 @@ class MistralAttention(torch.nn.Module): config.num_key_value_heads // weights.process_group.size() ) - self.query_key_value = TensorParallelColumnLinear.load_multi( + query_key_value = TensorParallelColumnLinear.load_multi( config, prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], dim=0, @@ -148,12 +145,31 @@ class MistralAttention(torch.nn.Module): bias=False, ) - self.o_proj = TensorParallelRowLinear.load( + head_size = config.hidden_size // config.num_attention_heads + self.query_key_value = TensorParallelMultiAdapterLinear.load( + query_key_value, + layer_id, + ["q_proj", "k_proj", "v_proj"], + sizes=[ + head_size * config.num_attention_heads, + head_size * config.num_key_value_heads, + head_size * config.num_key_value_heads, + ], + process_group=weights.process_group, + ) + + o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", weights=weights, bias=False, ) + self.o_proj = TensorParallelAdapterRowLinear.load( + o_proj, + layer_id, + "o_proj", + process_group=weights.process_group, + ) self.num_groups = self.num_heads // self.num_key_value_heads self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device @@ -171,8 +187,9 @@ class MistralAttention(torch.nn.Module): input_lengths, max_s, prefill_cache_indices, + adapter_data, ): - qkv = self.query_key_value(hidden_states) + qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( [ self.head_size * self.num_heads, @@ -224,11 +241,13 @@ class MistralAttention(torch.nn.Module): max_s, ) - return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) class MistralMLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, weights, layer_id): super().__init__() self.hidden_act = config.hidden_act self.act = ( @@ -244,19 +263,37 @@ class MistralMLP(nn.Module): ) ) # Fuse gate and up proj - self.gate_up_proj = TensorParallelColumnLinear.load_multi( + gate_up_proj = TensorParallelColumnLinear.load_multi( config, prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], weights=weights, dim=0, bias=False, ) - self.down_proj = TensorParallelRowLinear.load( + self.gate_up_proj = TensorParallelMultiAdapterLinear.load( + gate_up_proj, + layer_id, + ["gate_proj", "up_proj"], + sizes=[ + config.intermediate_size, + config.intermediate_size, + ], + process_group=weights.process_group, + ) + + down_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.down_proj", weights=weights, bias=False, ) + + self.down_proj = TensorParallelAdapterRowLinear.load( + down_proj, + layer_id, + "down_proj", + process_group=weights.process_group, + ) self.intermediate_size = ( config.intermediate_size // weights.process_group.size() ) @@ -264,7 +301,7 @@ class MistralMLP(nn.Module): # TODO: This is a hotfix to be removed & properly refactored. self.quantize = config.quantize - def forward(self, hidden_states): + def forward(self, hidden_states, adapter_data): if ( SYSTEM == "rocm" and self.hidden_act == "silu" @@ -278,20 +315,27 @@ class MistralMLP(nn.Module): device="cuda", ) _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) - return self.down_proj(out) + return self.down_proj(out, adapter_data) else: - gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = self.gate_up_proj(hidden_states, adapter_data) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + ) class MistralLayer(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, weights, layer_id): super().__init__() self.self_attn = MistralAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + layer_id=layer_id, + ) + self.mlp = MistralMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id ) - self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) self.input_layernorm = FastRMSNorm.load( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps @@ -315,6 +359,7 @@ class MistralLayer(nn.Module): input_lengths, max_s, prefill_cache_indices, + adapter_data, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -330,6 +375,7 @@ class MistralLayer(nn.Module): input_lengths, max_s, prefill_cache_indices, + adapter_data, ) # faster post attention rms norm @@ -337,7 +383,7 @@ class MistralLayer(nn.Module): attn_output, res ) - mlp_output = self.mlp(normed_attn_res_output) + mlp_output = self.mlp(normed_attn_res_output, adapter_data) return mlp_output, attn_res @@ -355,6 +401,7 @@ class MistralModel(torch.nn.Module): prefix=f"{prefix}.layers.{layer_id}", config=config, weights=weights, + layer_id=layer_id, ) for layer_id in range(config.num_hidden_layers) ] @@ -381,6 +428,7 @@ class MistralModel(torch.nn.Module): max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], + adapter_data: Optional[torch.Tensor] = None, ): hidden_states = inputs_embeds # Get rotary cos and sin for this forward @@ -403,6 +451,7 @@ class MistralModel(torch.nn.Module): input_lengths, max_s, prefill_cache_indices, + adapter_data, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -454,6 +503,7 @@ class FlashMistralForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: true_max_s = max_s if prefill_cache_indices is not None: @@ -476,6 +526,7 @@ class FlashMistralForCausalLM(torch.nn.Module): max_s, true_max_s, prefill_cache_indices, + adapter_data, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 2f7619af..2e839d15 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -638,6 +638,7 @@ class FlashMixtralForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: true_max_s = max_s if prefill_cache_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index d399be2f..b87fd4ca 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -390,6 +390,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.gpt_neox( input_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index 91c709e4..1f998e5a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -74,6 +74,7 @@ class PaliGemmaForConditionalGeneration(nn.Module): # Unused here pixel_attention_mask: Optional[torch.BoolTensor] = None, image_sizes: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.text_model.embed_tokens(input_ids) # TODO This is odd but apparently pali gemma position ids start at 1. diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 6dda4b2b..3f445f97 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -400,6 +400,7 @@ class FlashPhiForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model( input_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index df5a8ae9..69f38c3a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -359,6 +359,7 @@ class Qwen2ForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: true_max_s = max_s if prefill_cache_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 7d3c72a7..04d4ba51 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -672,6 +672,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer( input_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 2ae0908c..badfc367 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -483,6 +483,7 @@ class FlashSantacoderForCausalLM(nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.transformer( input_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index c3e2e099..f6a2e15d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -525,6 +525,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: true_max_s = max_s if prefill_cache_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index 51fd7c02..a83bc1c6 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -741,6 +741,7 @@ class Idefics2ForConditionalGeneration(nn.Module): pixel_attention_mask: Optional[torch.BoolTensor] = None, # Unused here image_sizes: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, ): inputs_embeds = self.text_model.embed_tokens(input_ids) if pixel_values is not None: diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index de9673aa..9a670140 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -178,6 +178,7 @@ class LlavaNextForConditionalGeneration(nn.Module): # Unused for this model pixel_attention_mask=None, image_sizes: Optional[torch.LongTensor] = None, + adapter_data: Optional[torch.Tensor] = None, ): inputs_embeds = self.language_model.embed_tokens(input_ids) if pixel_values is not None and len(pixel_values) > 0: diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 7f8268a9..f7678762 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -13,6 +13,7 @@ from opentelemetry import trace from transformers import PreTrainedTokenizerBase from typing import Iterable, Optional, Tuple, List, Type, Dict +from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.import_utils import SYSTEM @@ -31,6 +32,7 @@ from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS import text_generation_server.models.globals as tgi_globals from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION +from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments from text_generation_server.utils.import_utils import ( empty_cache, @@ -114,6 +116,9 @@ class FlashCausalLMBatch(Batch): top_n_tokens: List[int] top_n_tokens_tensor: torch.Tensor + # Adapter metadata for each request + adapter_meta: AdapterBatchMetadata + # Number of blocks in this batch num_blocks: int # Maximum number of blocks @@ -174,6 +179,9 @@ class FlashCausalLMBatch(Batch): stopping_criterias = [] top_n_tokens = [] + adapter_indices_list = [] + adapter_set = set() + # Cumulative length cumulative_length = 0 cumulative_max_length = 0 @@ -225,6 +233,10 @@ class FlashCausalLMBatch(Batch): stopping_criterias.append(stopping_criteria) top_n_tokens.append(r.top_n_tokens) + adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(r.adapter_id, 0) + adapter_indices_list.append(torch.full((input_length,), adapter_index)) + adapter_set.add(adapter_index) + # Paged attention # Remove one as the first token des not have a past speculative_length = get_speculate() @@ -296,6 +308,10 @@ class FlashCausalLMBatch(Batch): max_length, input_length + max_new_tokens + speculative_length ) + adapter_indices = torch.cat(adapter_indices_list).to( + dtype=torch.int64, device=device + ) + next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device, tokenizer ) @@ -339,6 +355,11 @@ class FlashCausalLMBatch(Batch): input_lengths, dtype=torch.int32, device=device ) + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + if all_prefill_logprobs: prefill_head_indices = None prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 @@ -393,6 +414,12 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor=top_n_tokens_tensor, num_blocks=num_blocks, max_blocks=max_blocks, + adapter_meta=AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ), speculative_ids=None, ) @@ -443,6 +470,7 @@ class FlashCausalLMBatch(Batch): stopping_criterias = [] top_n_tokens = [] + adapter_set = set() num_blocks = 0 max_blocks = 0 @@ -471,6 +499,11 @@ class FlashCausalLMBatch(Batch): top_n_tokens.append(self.top_n_tokens[idx]) + adapter_index = tgi_globals.ADAPTER_TO_INDEX.get( + self.requests[idx].adapter_id, 0 + ) + adapter_set.add(adapter_index) + remaining_tokens = ( stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) @@ -498,6 +531,7 @@ class FlashCausalLMBatch(Batch): # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] + adapter_indices = self.adapter_meta.adapter_indices[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices] block_tables_tensor = self.block_tables_tensor[indices] input_lengths_tensor = self.input_lengths_tensor[indices] @@ -513,6 +547,11 @@ class FlashCausalLMBatch(Batch): # Move to GPU now that we have the whole tensor slot_indices = slot_indices.to(device) + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + adapter_segments, dtype=torch.int32, device=device + ) + return type(self)( batch_id=self.batch_id, requests=requests, @@ -543,6 +582,12 @@ class FlashCausalLMBatch(Batch): num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, + adapter_meta=AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ), ) @classmethod @@ -596,6 +641,14 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( total_batch_size, ) + total_indices_size = sum( + b.adapter_meta.adapter_indices.shape[0] for b in batches + ) + adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty( + total_indices_size + ) + adapter_set = set() + adapter_segment_builder = SegmentConcatBuilder() start_slots = [] block_tables = [] @@ -613,6 +666,7 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_batch_size = 0 cumulative_slots = 0 + cumulative_adapter_indices_size = 0 for i, batch in enumerate(batches): requests.extend(batch.requests) @@ -637,6 +691,21 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor slots[slots_start_index:slots_end_index] = batch.slots + # Copy over adapter indices + adapter_start_index = cumulative_adapter_indices_size + adapter_end_index = ( + cumulative_adapter_indices_size + + batch.adapter_meta.adapter_indices.shape[0] + ) + adapter_indices[adapter_start_index:adapter_end_index] = ( + batch.adapter_meta.adapter_indices + ) + cumulative_adapter_indices_size = adapter_end_index + adapter_set.update(batch.adapter_meta.adapter_set) + adapter_segment_builder.concat( + batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices + ) + all_input_ids_tensor[ start_index:end_index, : batch.all_input_ids_tensor.shape[1] ] = batch.all_input_ids_tensor[:, :max_length] @@ -680,6 +749,8 @@ class FlashCausalLMBatch(Batch): else None ) + adapter_segments, adapter_segment_indices = adapter_segment_builder.build() + return cls( batch_id=batches[0].batch_id, requests=requests, @@ -710,6 +781,12 @@ class FlashCausalLMBatch(Batch): num_blocks=num_blocks, max_blocks=max_blocks, speculative_ids=speculative_ids, + adapter_meta=AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ), ) def __len__(self): @@ -719,6 +796,7 @@ class FlashCausalLMBatch(Batch): class FlashCausalLM(Model): def __init__( self, + model_id: str, model: torch.nn.Module, tokenizer: PreTrainedTokenizerBase, num_layers: int, @@ -738,6 +816,7 @@ class FlashCausalLM(Model): self.kv_cache = [] super(FlashCausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=False, @@ -895,12 +974,13 @@ class FlashCausalLM(Model): total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size free_memory = get_free_memory(self.device, MEMORY_FRACTION) + batch_num_blocks = batch.num_blocks if batch is not None else 0 num_blocks = ( # Leave 5% for some wiggle room int((free_memory * 0.95) // total_cache_size) # Add batch.num_blocks as we allocated it above, so it is included in the peak memory. - + batch.num_blocks + + batch_num_blocks ) del batch @@ -1001,7 +1081,7 @@ class FlashCausalLM(Model): ) def forward( - self, batch: FlashCausalLMBatch + self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Model Forward if batch.speculative_ids is not None: @@ -1080,6 +1160,7 @@ class FlashCausalLM(Model): max_s=max_s, prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, + adapter_data=adapter_data, ) if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None @@ -1116,7 +1197,34 @@ class FlashCausalLM(Model): prefill = batch.cu_seqlen_prefill is not None prefill_logprobs = batch.prefill_next_token_indices is not None - out, speculative_logits = self.forward(batch) + # Update adapter indices for speculative tokens (if present) + adapter_meta = batch.adapter_meta + if batch.speculative_ids is not None: + B, speculative_length = batch.speculative_ids.shape + new_length = speculative_length + 1 + adapter_indices = ( + adapter_meta.adapter_indices.unsqueeze(-1) + .expand(B, new_length) + .reshape(-1) + ) + adapter_segments = adapter_meta.adapter_segments * new_length + adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_meta.adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_meta.segment_indices, + ) + + # Assign pointers to adapter weights + # TODO(travis): don't update this if indices haven't changed + adapter_data = AdapterBatchData.from_meta( + adapter_meta, + self.layer_to_adapter_weights, + prefill, + batch.prefill_head_indices, + ) + + out, speculative_logits = self.forward(batch, adapter_data) if prefill: next_token_logits = ( @@ -1128,8 +1236,13 @@ class FlashCausalLM(Model): if prefill_logprobs else speculative_logits ) + next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty( + len(batch) + ) + else: next_token_logits = out + next_adapter_indices = batch.adapter_meta.adapter_indices speculate = get_speculate() ( @@ -1195,6 +1308,12 @@ class FlashCausalLM(Model): # In decode, we do not need this as we can just increment position ids next_position_ids[i] = batch.position_ids[end_index - 1] + # Initialize adapter indices + # In decode, we only have one token per row in the batch, so grab last index + next_adapter_indices[i] = batch.adapter_meta.adapter_indices[ + end_index - 1 + ] + # Used to gather prefill logprobs # Copy batch.input_ids to prefill_token_indices if prefill_logprobs: @@ -1220,6 +1339,16 @@ class FlashCausalLM(Model): batch.position_ids = next_position_ids + accepted_ids batch.input_lengths_tensor += accepted_ids batch.slot_indices += accepted_ids + batch.adapter_meta.adapter_indices = next_adapter_indices + + if prefill: + # adjust segment lengths to account for all request lengths being 1 during decoding + adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices) + batch.adapter_meta.adapter_segments = torch.tensor( + adapter_segments, + dtype=torch.int32, + device=batch.adapter_meta.adapter_segments.device, + ) if prefill and prefill_logprobs: # Get prefill logprobs diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py index 1077d78e..9f8bcb3f 100644 --- a/server/text_generation_server/models/flash_cohere.py +++ b/server/text_generation_server/models/flash_cohere.py @@ -62,6 +62,7 @@ class FlashCohere(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashCohere, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/flash_dbrx.py b/server/text_generation_server/models/flash_dbrx.py index ffb6d5a6..2aba6a00 100644 --- a/server/text_generation_server/models/flash_dbrx.py +++ b/server/text_generation_server/models/flash_dbrx.py @@ -87,6 +87,7 @@ class FlashDbrx(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashDbrx, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py index 1b7b2772..aa1ae9ac 100644 --- a/server/text_generation_server/models/flash_gemma.py +++ b/server/text_generation_server/models/flash_gemma.py @@ -62,6 +62,7 @@ class FlashGemma(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashGemma, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py index 2d0f9fcc..323fcafa 100644 --- a/server/text_generation_server/models/flash_gpt2.py +++ b/server/text_generation_server/models/flash_gpt2.py @@ -69,6 +69,7 @@ class FlashGPT2(FlashCausalLM): model = FlashGPT2ForCausalLM(prefix, config, weights) torch.distributed.barrier(group=self.process_group) super(FlashGPT2, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 9366706f..d996b9c3 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -1,9 +1,10 @@ +import os import torch import torch.distributed from opentelemetry import trace from transformers import AutoConfig, AutoTokenizer, GenerationConfig -from typing import Optional +from typing import Optional, Tuple, Dict, List from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_llama_modeling import ( @@ -13,12 +14,24 @@ from text_generation_server.utils import ( initialize_torch_distributed, weight_files, Weights, + hub, ) tracer = trace.get_tracer(__name__) from text_generation_server.utils.import_utils import SYSTEM +ADAPTER_LAYERS = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] +ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} + class FlashLlama(FlashCausalLM): def __init__( @@ -29,6 +42,7 @@ class FlashLlama(FlashCausalLM): speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, + lora_adapter_ids: Optional[list] = [], ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -78,6 +92,7 @@ class FlashLlama(FlashCausalLM): model = FlashLlamaForCausalLM(prefix, config, weights) torch.distributed.barrier(group=self.process_group) super(FlashLlama, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), @@ -88,3 +103,69 @@ class FlashLlama(FlashCausalLM): rank=rank, world_size=world_size, ) + + @property + def supports_adapter_loading(self) -> bool: + return True + + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: + layer_weights = {} + + prefix = "model.layers" + + # This accounts for VLMs (e.g. LlavaNext, Idefics2) + # that have a language_model inside of the larger model. + if hasattr(self.model, "language_model"): + _model = self.model.language_model + elif hasattr(self.model, "text_model"): + _model = self.model.text_model + else: + _model = self.model + + for i, layer in enumerate(_model.model.layers): + layer_weights[(i, "q_proj")] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "k_proj")] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "v_proj")] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "o_proj")] = ( + f"{prefix}.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) + + layer_weights[(i, "gate_proj")] = ( + f"{prefix}.{i}.mlp.gate_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, "up_proj")] = ( + f"{prefix}.{i}.mlp.up_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, "down_proj")] = ( + f"{prefix}.{i}.mlp.down_proj", + layer.mlp.down_proj, + ) + + layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) + return layer_weights + + @property + def adapter_layers(self) -> List[str]: + return ADAPTER_LAYERS + + @property + def default_traced_adapter_layers(self) -> List[str]: + return ["q_proj", "v_proj"] + + def get_num_layers_for_type(self, layer_type: str) -> int: + return 1 if layer_type == "lm_head" else len(self.model.model.layers) + + def is_row_parallel(self, layer_type: str) -> bool: + return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 16778ada..209eca83 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -3,7 +3,7 @@ import torch.distributed from opentelemetry import trace from transformers import AutoTokenizer, AutoConfig -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict, List from text_generation_server.models import FlashCausalLM from text_generation_server.models.flash_causal_lm import set_sliding_window @@ -21,6 +21,18 @@ from text_generation_server.utils.import_utils import SYSTEM tracer = trace.get_tracer(__name__) +ADAPTER_LAYERS = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] +ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} + + class BaseFlashMistral(FlashCausalLM): def __init__( self, @@ -83,6 +95,7 @@ class BaseFlashMistral(FlashCausalLM): torch.distributed.barrier(group=self.process_group) num_layers, num_kv_heads, head_size = self.get_layer_config(model) super().__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=num_layers, @@ -102,6 +115,75 @@ class BaseFlashMistral(FlashCausalLM): model.model.head_size, ) + @property + def supports_adapter_loading(self) -> bool: + return True + + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: + layer_weights = {} + + prefix = "model.layers" + + # This accounts for VLMs (e.g. LlavaNext, Idefics2) + # that have a language_model inside of the larger model. + if hasattr(self.model, "language_model"): + _model = self.model.language_model + elif hasattr(self.model, "text_model"): + _model = self.model.text_model + else: + _model = self.model + + for i, layer in enumerate(_model.model.layers): + layer_weights[(i, "q_proj")] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "k_proj")] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "v_proj")] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "o_proj")] = ( + f"{prefix}.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) + + # TODO: this is a hack to avoid the gate_proj for + # FlashStarcoder2 that doesnt have these layers + if hasattr(layer.mlp, "gate_up_proj"): + layer_weights[(i, "gate_proj")] = ( + f"{prefix}.{i}.mlp.gate_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, "up_proj")] = ( + f"{prefix}.{i}.mlp.up_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, "down_proj")] = ( + f"{prefix}.{i}.mlp.down_proj", + layer.mlp.down_proj, + ) + + layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) + return layer_weights + + @property + def adapter_layers(self) -> List[str]: + return ADAPTER_LAYERS + + @property + def default_traced_adapter_layers(self) -> List[str]: + return ["q_proj", "v_proj"] + + def get_num_layers_for_type(self, layer_type: str) -> int: + return 1 if layer_type == "lm_head" else len(self.model.model.layers) + + def is_row_parallel(self, layer_type: str) -> bool: + return layer_type in ROW_PARALLEL + class FlashMistral(BaseFlashMistral): def __init__( diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 87ae570c..ac1fd573 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -69,6 +69,7 @@ class FlashNeoXSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashNeoXSharded, self).__init__( + model_id=model_id, model=model.to(device), tokenizer=tokenizer, num_layers=len(model.gpt_neox.layers), diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py index 0cc67cec..7e108d05 100644 --- a/server/text_generation_server/models/flash_phi.py +++ b/server/text_generation_server/models/flash_phi.py @@ -90,6 +90,7 @@ class FlashPhi(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashPhi, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py index 9fcfce9d..23528f0b 100644 --- a/server/text_generation_server/models/flash_qwen2.py +++ b/server/text_generation_server/models/flash_qwen2.py @@ -71,6 +71,7 @@ class FlashQwen2(BaseFlashMistral): torch.distributed.barrier(group=self.process_group) super(BaseFlashMistral, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 6ed1f6f7..b1f75adc 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -78,6 +78,7 @@ class FlashRWSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashRWSharded, self).__init__( + model_id=model_id, model=model.to(device), tokenizer=tokenizer, num_layers=len(model.transformer.h), diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index ab1e4516..e1a7b36e 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -80,6 +80,7 @@ class FlashSantacoderSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) super(FlashSantacoderSharded, self).__init__( + model_id=model_id, model=model.to(device), tokenizer=tokenizer, num_layers=len(model.transformer.h), diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py index 1ac731be..369e9e4c 100644 --- a/server/text_generation_server/models/flash_starcoder2.py +++ b/server/text_generation_server/models/flash_starcoder2.py @@ -70,6 +70,7 @@ class FlashStarcoder2(BaseFlashMistral): torch.distributed.barrier(group=self.process_group) super(BaseFlashMistral, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, num_layers=len(model.model.layers), diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index f39bd1e9..30c92d90 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -212,6 +212,7 @@ class GalacticaSharded(CausalLM): torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 970a673b..cc2f172a 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -1,6 +1,7 @@ import torch import os from loguru import logger +from typing import Dict MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli @@ -32,3 +33,14 @@ MODEL_ID = None def set_model_id(model_id: str): global MODEL_ID MODEL_ID = model_id + + +# NOTE: eventually we should move this into the router and pass back the +# index in all cases. +global ADAPTER_TO_INDEX +ADAPTER_TO_INDEX: Dict[str, int] = None + + +def set_adapter_to_index(adapter_to_index: Dict[str, int]): + global ADAPTER_TO_INDEX + ADAPTER_TO_INDEX = adapter_to_index diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index 8d2cb0e1..c37cfb7d 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -65,6 +65,7 @@ class GPTNeoxSharded(CausalLM): torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index c1fe03e4..f2955bd0 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -83,6 +83,7 @@ class IDEFICSSharded(IdeficsCausalLM): torch.distributed.barrier(group=self.process_group) super(IdeficsCausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index f507d669..6c562980 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -634,6 +634,7 @@ class IdeficsCausalLM(Model): tokenizer.add_special_tokens({"pad_token": ""}) super(IdeficsCausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 3133a137..9189b45c 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -453,6 +453,7 @@ class Mamba(Model): model = MambaModel(config, weights) torch.distributed.barrier(group=self.process_group) super(Mamba, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 4f35b0aa..c90fd38a 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -2,12 +2,24 @@ import inspect import torch from abc import ABC, abstractmethod -from typing import List, Tuple, Optional, TypeVar, Type +from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict +from collections import defaultdict from transformers import PreTrainedTokenizerBase, PretrainedConfig from text_generation_server.models.types import Batch, Generation from text_generation_server.utils.speculate import get_speculate from text_generation_server.pb.generate_pb2 import InfoResponse +from text_generation_server.adapters.weights import LayerAdapterWeights +from text_generation_server.utils.adapter import ( + load_and_merge_adapters, + AdapterParameters, + AdapterSource, +) +from loguru import logger + + +BASE_MODEL_ADAPTER_ID = "__base_model__" + B = TypeVar("B", bound=Batch) @@ -15,6 +27,7 @@ B = TypeVar("B", bound=Batch) class Model(ABC): def __init__( self, + model_id: str, model: torch.nn.Module, tokenizer: PreTrainedTokenizerBase, requires_padding: bool, @@ -24,7 +37,9 @@ class Model(ABC): world_size: int = 1, sliding_window: Optional[int] = None, speculate: Optional[int] = None, + adapter_id: str = BASE_MODEL_ADAPTER_ID, ): + self.model_id = model_id self.model = model.eval() self.tokenizer = tokenizer @@ -42,6 +57,13 @@ class Model(ABC): self.world_size = world_size self.sliding_window = sliding_window if sliding_window != -1 else None + self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict( + LayerAdapterWeights + ) + self.target_to_layer = self.adapter_target_to_layer() + self.loaded_adapters = set() + self.static_adapter_id = adapter_id + if speculate is None: speculate = get_speculate() self.speculate = speculate @@ -119,3 +141,136 @@ class Model(ABC): raise RuntimeError( f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}" ) + + @property + def supports_adapter_loading(self) -> bool: + return False + + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: + return {} + + @property + def adapter_layers(self) -> List[str]: + return [] + + @property + def default_traced_adapter_layers(self) -> List[str]: + return [] + + def get_num_layers_for_type(self, layer_type: str) -> int: + return 0 + + def is_row_parallel(self, layer_type: str) -> bool: + return False + + @property + def max_speculative_tokens(self) -> int: + return max( + [ + weights.max_speculative_tokens + for weights in self.layer_to_adapter_weights.values() + ], + default=0, + ) + + def load_adapter( + self, + adapter_parameters: AdapterParameters, + adapter_source: AdapterSource, + adapter_index: int, + api_token: str, + dynamic: bool = True, + ): + """Loads adapter weights from disk / host memory on the GPU. + + adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded + into model. Otherwise, the adapter weights are applied during the forward + pass and stored separately from the base model parameters. + """ + if adapter_index in self.loaded_adapters: + # Adapter already loaded + return + + if not self.supports_adapter_loading: + raise ValueError("This model does not support adapter loading.") + + if dynamic and not self.dynamic_adapter_loading_enabled: + raise ValueError( + f"This model was initialized with the adapter {self.static_adapter_id} " + f"and therefore does not support dynamic adapter loading. " + f"Please initialize a new model instance from the base model in " + f"order to use the dynamic adapter loading feature." + ) + + logger.info( + f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}" + ) + weight_names = tuple([v[0] for v in self.target_to_layer.values()]) + ( + module_map, + adapter_config, + adapter_weight_names, + adapter_tokenizer, + ) = load_and_merge_adapters( + self.model_id, + adapter_parameters, + adapter_source, + adapter_index, + weight_names, + api_token, + False, + ) + + unused_weight_names = adapter_weight_names.copy() + for layer_name in self.adapter_layers: + adapter_weights = adapter_config.load_batched_adapter_weights( + self, + module_map, + layer_name, + unused_weight_names, + dynamic, + ) + + if adapter_weights is None: + continue + + layer_weights = self.layer_to_adapter_weights[layer_name] + layer_weights.add_adapter(adapter_index, adapter_weights) + + if len(unused_weight_names) > 0: + logger.warning( + f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}" + ) + + if adapter_tokenizer is not None: + self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer) + + self.loaded_adapters.add(adapter_index) + + def offload_adapter( + self, + adapter_parameters: AdapterParameters, + adapter_source: AdapterSource, + adapter_index: int, + ): + """Offloads the adapter weights from GPU to CPU or disk.""" + if adapter_index not in self.loaded_adapters: + # Adapter already offloaded + return + + if not self.supports_adapter_loading: + raise ValueError("This model does not support adapter loading.") + + if not self.dynamic_adapter_loading_enabled: + raise ValueError( + f"This model was initialized with the adapter {self.static_adapter_id} " + f"and therefore does not support dynamic adapter loading. " + f"Please initialize a new model instance from the base model in " + f"order to use the dynamic adapter loading feature." + ) + + for layer_name in self.adapter_layers: + if layer_name in self.layer_to_adapter_weights: + self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index) + + self.loaded_adapters.remove(adapter_index) diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py index 65180e73..1e79b25f 100644 --- a/server/text_generation_server/models/mpt.py +++ b/server/text_generation_server/models/mpt.py @@ -90,6 +90,7 @@ class MPTSharded(CausalLM): torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=False, diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 1f4fbfcd..6d7d07f5 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -63,6 +63,7 @@ class OPTSharded(CausalLM): torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py index d68866c1..93d42b2b 100644 --- a/server/text_generation_server/models/phi.py +++ b/server/text_generation_server/models/phi.py @@ -60,6 +60,7 @@ class Phi(CausalLM): model = PhiForCausalLM(config, weights) torch.distributed.barrier(group=self.process_group) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py index 50f6ead8..37ca277b 100644 --- a/server/text_generation_server/models/rw.py +++ b/server/text_generation_server/models/rw.py @@ -62,6 +62,7 @@ class RW(CausalLM): tokenizer.add_special_tokens({"pad_token": "[PAD]"}) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py index 323e4324..caddbe19 100644 --- a/server/text_generation_server/models/santacoder.py +++ b/server/text_generation_server/models/santacoder.py @@ -62,6 +62,7 @@ class SantaCoder(CausalLM): ) super(CausalLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 3bd09556..d454d804 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -575,6 +575,7 @@ class Seq2SeqLM(Model): tokenizer.bos_token_id = model.config.decoder_start_token_id super(Seq2SeqLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 8e0735e5..adef664c 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -73,6 +73,7 @@ class T5Sharded(Seq2SeqLM): torch.distributed.barrier(group=self.process_group) super(Seq2SeqLM, self).__init__( + model_id=model_id, model=model, tokenizer=tokenizer, requires_padding=True, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 8b5819d1..218d1167 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -222,7 +222,9 @@ class VlmCausalLM(BaseFlashMistral): return VlmCausalLMBatch def forward( - self, batch: VlmCausalLMBatch + self, + batch: VlmCausalLMBatch, + adapter_data: Optional[Dict[str, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Model Forward if batch.speculative_ids is not None: diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index a0347cd8..aee287c6 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -29,7 +29,10 @@ except (ImportError, NotImplementedError): from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor -from text_generation_server.models.globals import set_model_id +from text_generation_server.models.globals import set_model_id, set_adapter_to_index +from text_generation_server.utils.adapter import ( + AdapterParameters, +) class SignalHandler: @@ -192,6 +195,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): def serve( model_id: str, + lora_adapter_ids: Optional[List[str]], revision: Optional[str], sharded: bool, quantize: Optional[str], @@ -203,6 +207,7 @@ def serve( ): async def serve_inner( model_id: str, + lora_adapter_ids: Optional[List[str]], revision: Optional[str], sharded: bool = False, quantize: Optional[str] = None, @@ -211,6 +216,7 @@ def serve( trust_remote_code: bool = False, ): unix_socket_template = "unix://{}-{}" + adapter_to_index = {} if sharded: server_urls = [ unix_socket_template.format(uds_path, rank) @@ -224,6 +230,7 @@ def serve( try: model = get_model( model_id, + lora_adapter_ids, revision, sharded, quantize, @@ -232,10 +239,33 @@ def serve( trust_remote_code, max_input_tokens, ) + + if len(lora_adapter_ids) > 0: + for index, adapter_id in enumerate(lora_adapter_ids): + # TODO: improve non merged adapter loading and long term + # improve adapter loading as a whole + adapter_parameters = AdapterParameters( + adapter_ids=[adapter_id], + weights=None, # will be set to 1 + merge_strategy=0, + density=1.0, + majority_sign_method=0, + ) + adapter_index = index + 1 + adapter_to_index[adapter_id] = adapter_index + model.load_adapter( + adapter_parameters, + None, # adapter_source + adapter_index, + None, # api_token + False, # dynamic + ) + except Exception: logger.exception("Error when initializing model") raise + set_adapter_to_index(adapter_to_index) server = aio.server( interceptors=[ ExceptionInterceptor(), @@ -266,6 +296,13 @@ def serve( set_model_id(model_id) asyncio.run( serve_inner( - model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code + model_id, + lora_adapter_ids, + revision, + sharded, + quantize, + speculate, + dtype, + trust_remote_code, ) ) diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py new file mode 100644 index 00000000..4e2492de --- /dev/null +++ b/server/text_generation_server/utils/adapter.py @@ -0,0 +1,196 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/utils/adapter.py +# License: Apache License Version 2.0, January 2004 + +import warnings +from dataclasses import dataclass +from functools import lru_cache +from typing import TYPE_CHECKING, Set, Tuple + +from safetensors.torch import load_file +from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer + +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils.merges.strategies import merge_adapters + +from text_generation_server.utils import hub +from text_generation_server.adapters.lora import LoraConfig + + +if TYPE_CHECKING: + from text_generation_server.adapters.config import AdapterConfig, ModuleMap + + +BASE_MODEL_ADAPTER_ID = "__base_model__" + + +@dataclass +class AdapterParameters: + adapter_ids: Tuple[str] + weights: Tuple[float] + merge_strategy: NotImplemented + density: float + majority_sign_method: NotImplemented + + +@dataclass +class AdapterSource: + adapter_id: str + model_id: str + revision: str + + +def load_and_merge_adapters( + model_id: str, + adapter_parameters: AdapterParameters, + adapter_source: str, + adapter_index: int, + weight_names: Tuple[str], + api_token: str, + trust_remote_code: bool = False, +) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: + if len(adapter_parameters.adapter_ids) == 1: + return load_module_map( + model_id, + adapter_parameters.adapter_ids[0], + adapter_source, + weight_names, + api_token, + trust_remote_code, + ) + + adapter_params = AdapterParametersContainer( + adapter_parameters, adapter_source, adapter_index + ) + return _load_and_merge( + model_id, adapter_params, weight_names, api_token, trust_remote_code + ) + + +@dataclass +class AdapterParametersContainer: + adapter_parameters: AdapterParameters + adapter_source: str + adapter_index: int + + def __hash__(self) -> int: + return self.adapter_index + + +@lru_cache(maxsize=32) +def _load_and_merge( + model_id: str, + adapter_params: AdapterParametersContainer, + weight_names: Tuple[str], + api_token: str, + trust_remote_code: bool = False, +) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: + params = adapter_params.adapter_parameters + + adapters_to_merge = [] + merged_weight_names = set() + tokenizer = None + for adapter_id in params.adapter_ids: + if adapter_id == BASE_MODEL_ADAPTER_ID: + raise ValueError("Base model adapter cannot be merged.") + + module_map, adapter_config, adapter_weight_names, adapter_tokenizer = ( + load_module_map( + model_id, + adapter_id, + adapter_params.adapter_source, + weight_names, + api_token, + trust_remote_code, + ) + ) + + adapters_to_merge.append((module_map, adapter_config)) + merged_weight_names = merged_weight_names.union(adapter_weight_names) + if tokenizer is None: + tokenizer = adapter_tokenizer + + if len(adapters_to_merge) == 0: + raise ValueError("No adapters to merge.") + + module_map, adapter_config = merge_adapters(adapters_to_merge, params) + return module_map, adapter_config, merged_weight_names, tokenizer + + +def check_architectures( + model_id: str, + adapter_id: str, + adapter_config: "AdapterConfig", + trust_remote_code: bool = False, +): + try: + if not adapter_config.base_model_name_or_path: + # Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None) + return + + expected_config = AutoConfig.from_pretrained( + model_id, trust_remote_code=trust_remote_code + ) + model_config = AutoConfig.from_pretrained( + adapter_config.base_model_name_or_path, trust_remote_code=trust_remote_code + ) + except Exception as e: + warnings.warn( + f"Unable to check architecture compatibility for adapter '{adapter_id}' " + f"against model '{model_id}'. Assuming they are compatible. Error: {e}" + ) + return + + if model_config.architectures == expected_config.architectures: + warnings.warn( + f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. " + f"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead." + ) + else: + # TODO(travis): revisit this when we support clasification heads which will not use CausalLM + raise ValueError( + f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. " + f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. " + f"Use --model-id '{adapter_config.base_model_name_or_path}' instead." + ) + + +@lru_cache(maxsize=128) +def load_module_map( + model_id: str, + adapter_id: str, + adapter_source: str, + weight_names: Tuple[str], + api_token: str, + trust_remote_code: bool = False, +) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]: + revision = "main" + + adapter_config = LoraConfig.load(adapter_id, api_token) + if adapter_config.base_model_name_or_path != model_id: + check_architectures(model_id, adapter_id, adapter_config, trust_remote_code) + + adapter_filenames = hub._cached_adapter_weight_files( + adapter_id, revision=revision, extension=".safetensors" + ) + + try: + adapter_tokenizer = AutoTokenizer.from_pretrained( + adapter_config.config_path, + token=api_token, + trust_remote_code=trust_remote_code, + ) + except Exception: + # Adapter does not have a tokenizer, so fallback to base model tokenizer + adapter_tokenizer = None + + # load adapter weights from all shards (should have relatively small memory footprint) + adapter_weights = {} + for filename in adapter_filenames: + adapter_weights.update(load_file(filename)) + + # map the model weights to the relevant adapter weights (LoRA A and B matrices) + module_map, adapter_weight_names = adapter_config.map_weights_for_model( + adapter_weights, weight_names + ) + return module_map, adapter_config, adapter_weight_names, adapter_tokenizer diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index b56484f6..db412aeb 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -18,6 +18,17 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None) HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"] +def _cached_adapter_weight_files( + adapter_id: str, revision: Optional[str], extension: str +) -> List[str]: + """Guess weight files from the cached revision snapshot directory""" + d = _get_cached_revision_directory(adapter_id, revision) + if not d: + return [] + filenames = _adapter_weight_files_from_dir(d, extension) + return filenames + + def _cached_weight_files( model_id: str, revision: Optional[str], extension: str ) -> List[str]: @@ -60,6 +71,33 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]: return filenames +def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]: + # os.walk: do not iterate, just scan for depth 1, not recursively + # see _weight_files_from_dir, that's also what is done there + root, _, files = next(os.walk(str(d))) + filenames = [ + os.path.join(root, f) + for f in files + if f.endswith(extension) + and "arguments" not in f + and "args" not in f + and "training" not in f + ] + return filenames + + +def _adapter_config_files_from_dir(d: Path) -> List[str]: + # os.walk: do not iterate, just scan for depth 1, not recursively + # see _weight_files_from_dir, that's also what is done there + root, _, files = next(os.walk(str(d))) + filenames = [ + os.path.join(root, f) + for f in files + if f.endswith(".json") and "arguments" not in f and "args" not in f + ] + return filenames + + def _get_cached_revision_directory( model_id: str, revision: Optional[str] ) -> Optional[Path]: diff --git a/server/text_generation_server/utils/merges/strategies.py b/server/text_generation_server/utils/merges/strategies.py new file mode 100644 index 00000000..3b885313 --- /dev/null +++ b/server/text_generation_server/utils/merges/strategies.py @@ -0,0 +1,223 @@ +import copy +from abc import ABC +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union + +import torch + + +class AdapterParameters: + def __init__( + self, adapter_ids, weights, merge_strategy, density, majority_sign_method + ): + self.adapter_ids = adapter_ids + self.weights = weights + self.merge_strategy = merge_strategy + self.density = density + self.majority_sign_method = majority_sign_method + + +from text_generation_server.utils.merges.utils import ( + calculate_majority_sign_mask, + disjoint_merge, + prune, +) + +if TYPE_CHECKING: + from text_generation_server.adapters.lora import LoraConfig + from text_generation_server.utils.adapter import ModuleMap + + +def _apply_weights( + tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor +) -> torch.Tensor: + if isinstance(tensors, torch.Tensor): + t = tensors + else: + t = torch.stack(tensors, dim=0) + + # element-wise weighting of each task tensor + # need to unsqueeze weights to match task tensor dimensions + # for multiplication to apply element-wise + while len(t.shape) > len(w.shape): + w = w.unsqueeze(-1) + return t * w + + +class MergeStrategy(ABC): + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + raise NotImplementedError() + + +class LinearMerge(MergeStrategy): + def __init__(self, **kwargs): + pass + + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + weighted_task_tensors = _apply_weights(task_tensors, weights) + return weighted_task_tensors.sum(dim=0) + + +class TiesMerge(MergeStrategy): + def __init__(self, density: float, majority_sign_method: str = "total", **kwargs): + self.density = density + self.majority_sign_method = majority_sign_method + + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + # sparsify + task_tensors = [ + prune(tensor, self.density, method="magnitude") for tensor in task_tensors + ] + task_tensors = torch.stack(task_tensors, dim=0) + + # elect sign before applying weights + majority_sign_mask = calculate_majority_sign_mask( + task_tensors, method=self.majority_sign_method + ) + weighted_task_tensors = _apply_weights(task_tensors, weights) + + # disjoint merge + return disjoint_merge(weighted_task_tensors, majority_sign_mask) + + +class DareLinearMerge(MergeStrategy): + def __init__(self, density: float, **kwargs): + self.density = density + + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + # sparsify + task_tensors = [ + prune(tensor, self.density, method="random", rescale=True) + for tensor in task_tensors + ] + weighted_task_tensors = _apply_weights(task_tensors, weights) + return weighted_task_tensors.sum(dim=0) + + +class DareTiesMerge(MergeStrategy): + def __init__(self, density: float, majority_sign_method: str = "total", **kwargs): + self.density = density + self.majority_sign_method = majority_sign_method + + def merge( + self, task_tensors: List[torch.Tensor], weights: torch.Tensor + ) -> torch.Tensor: + # sparsify + task_tensors = [ + prune(tensor, self.density, method="random", rescale=True) + for tensor in task_tensors + ] + task_tensors = torch.stack(task_tensors, dim=0) + + # elect sign before applying weights + majority_sign_mask = calculate_majority_sign_mask( + task_tensors, method=self.majority_sign_method + ) + weighted_task_tensors = _apply_weights(task_tensors, weights) + + # disjoint merge + mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask) + return mixed_task_tensors + + +strategy_registry: Dict[str, Type[MergeStrategy]] = { + "linear": LinearMerge, + "ties": TiesMerge, + "dare_linear": DareLinearMerge, + "dare_ties": DareTiesMerge, +} + + +def merge_adapters( + adapters: List[Tuple["ModuleMap", "LoraConfig"]], + merge_params: AdapterParameters, +) -> Tuple["ModuleMap", "LoraConfig"]: + # strategy_name = MergeStrategyEnum.Name(merge_params.merge_strategy).lower() + strategy_name = "linear" + + weights = merge_params.weights + if not weights: + weights = torch.ones(len(adapters)) + else: + weights = torch.tensor(weights) + + merge_config = { + "density": merge_params.density, + # "majority_sign_method": MajoritySignMethodEnum.Name( + # merge_params.majority_sign_method + # ).lower(), + "majority_sign_method": "total", + } + merge_strategy = strategy_registry[strategy_name](**merge_config) + + module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict( + lambda: defaultdict(lambda: defaultdict(list)) + ) + lora_configs = [] + weight_name_to_adapter_idx = defaultdict(list) + + # input is list of (module_map, lora_config) tuples + # convert into dict[k][param_name] -> list of tensors + for idx, (module_map, lora_config) in enumerate(adapters): + for weight_name, data in module_map.items(): + weight_name_to_adapter_idx[weight_name].append(idx) + for k, (param_data, param_name) in data.items(): + module_maps[weight_name][k][param_name].append(param_data) + lora_configs.append(lora_config) + + # validate lora configs are compatible + _validate_lora_configs(lora_configs) + + # merge tensors for each module such that we have a single ModuleMap: + # dict[k] -> merged tensor + merged_module_map: "ModuleMap" = defaultdict(dict) + for weight_name, data in module_maps.items(): + indices = weight_name_to_adapter_idx[weight_name] + param_weights = weights[indices] + for k, param_data in data.items(): + for param_name, tensors in param_data.items(): + merged_tensor = merge_strategy.merge(tensors, param_weights) + merged_module_map[weight_name][k] = (merged_tensor, param_name) + + # merge lora configs + merged_lora_config = _merge_lora_configs(lora_configs) + + return merged_module_map, merged_lora_config + + +def _validate_lora_configs(lora_configs: List["LoraConfig"]): + # check that all configs have the same rank + ranks = set(lora_config.r for lora_config in lora_configs) + if len(ranks) > 1: + raise ValueError( + f"unable to merge adapters, lora configs have different ranks: {ranks}" + ) + + if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs): + raise ValueError( + "unable to merge adapters, lora configs have no target modules" + ) + + +def _merge_lora_configs(lora_configs: List["LoraConfig"]) -> "LoraConfig": + merged_lora_config = copy.copy(lora_configs[0]) + + # merge target modules as a union operation + merged_target_modules = sorted( + set( + module + for lora_config in lora_configs + for module in lora_config.target_modules + ) + ) + merged_lora_config.target_modules = merged_target_modules + + return merged_lora_config diff --git a/server/text_generation_server/utils/merges/utils.py b/server/text_generation_server/utils/merges/utils.py new file mode 100644 index 00000000..d9ad3278 --- /dev/null +++ b/server/text_generation_server/utils/merges/utils.py @@ -0,0 +1,108 @@ +# coding=utf-8 +# From: https://github.com/huggingface/peft/pull/1364 +# Copyright 2024-present the HuggingFace Inc. team. +# Modifications by Predibase, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Literal + +import torch + + +def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor: + """ + Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction + `density`. + + Args: + tensor (`torch.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + """ + mask = torch.zeros_like(tensor).reshape(-1) + k = int(density * tensor.reshape(-1).shape[0]) + top_k = torch.topk(tensor.abs().reshape(-1), k=k, largest=True) + mask[top_k[1]] = 1 + return tensor * mask.reshape(tensor.shape) + + +def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor: + """ + Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction + `density`. + + Args: + tensor (`torch.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor. + """ + mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density)) + pruned_tensor = tensor * mask + if rescale: + torch.div(input=pruned_tensor, other=density) + return pruned_tensor + + +def prune( + tensor: torch.Tensor, + density: float, + method: Literal["magnitude", "random"], + rescale: bool = False, +) -> torch.Tensor: + """ + Prune the values of task tensors based on the `method`. + + Args: + tensor (`torch.Tensor`):The tensor to prune. + density (`float`):The fraction of values to preserve. Should be in [0,1]. + method (`str`):The method to use to prune. Should be one of ["magnitude", "random"]. + rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor. + """ + if density >= 1: + return tensor + elif density < 0: + raise ValueError("Density should be >= 0, got {density}") + if method == "magnitude": + return magnitude_based_pruning(tensor, density) + elif method == "random": + return random_pruning(tensor, density, rescale=rescale) + else: + raise ValueError(f"Unknown method {method}") + + +def calculate_majority_sign_mask( + tensor: torch.Tensor, method: Literal["total", "frequency"] = "total" +): + """ + Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0. + + Args: + tensor (`torch.Tensor`):The tensor to get the mask from. + method (`str`):The method to use to get the mask. Should be one of ["total", "frequency"]. + """ + + sign = tensor.sign() + if method == "total": + sign_magnitude = (sign * tensor.abs()).sum(dim=0) + elif method == "frequency": + sign_magnitude = sign.sum(dim=0) + else: + raise RuntimeError(f'Unimplemented mask method "{method}"') + majority_sign = torch.where(sign_magnitude >= 0, 1, -1) + return sign == majority_sign + + +def disjoint_merge(task_tensors, majority_sign_mask): + mixed_task_tensors = (task_tensors * majority_sign_mask).sum(dim=0) + num_params_preserved = majority_sign_mask.sum(dim=0) + return mixed_task_tensors / torch.clamp(num_params_preserved, min=1.0) diff --git a/server/text_generation_server/utils/peft.py b/server/text_generation_server/utils/peft.py index 48ca264b..0ea89267 100644 --- a/server/text_generation_server/utils/peft.py +++ b/server/text_generation_server/utils/peft.py @@ -1,5 +1,5 @@ import os -import json +from typing import Union from loguru import logger import torch @@ -43,3 +43,26 @@ def download_and_unload_peft(model_id, revision, trust_remote_code): model.save_pretrained(cache_dir, safe_serialization=True) model.config.save_pretrained(cache_dir) tokenizer.save_pretrained(cache_dir) + + +def download_peft( + model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool +): + torch_dtype = torch.float16 + try: + _model = AutoPeftModelForCausalLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + ) + except Exception: + _model = AutoPeftModelForSeq2SeqLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + ) + logger.info("Peft model downloaded.") diff --git a/server/text_generation_server/utils/segments.py b/server/text_generation_server/utils/segments.py new file mode 100644 index 00000000..f5961102 --- /dev/null +++ b/server/text_generation_server/utils/segments.py @@ -0,0 +1,66 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/utils/segments.py +# License: Apache License Version 2.0, January 2004 + +from typing import List, Tuple, Union + +import torch + + +def find_segments( + adapter_indices: Union[torch.Tensor, List[int]] +) -> Tuple[List[int], List[int]]: + segments = [0] + segment_indices = [] + + if isinstance(adapter_indices, torch.Tensor): + # Calling .item() repeatedly on CUDA tensor is very slow, so we move it to CPU first + adapter_indices = adapter_indices.cpu().tolist() + + start_index = 0 + for i in range(1, len(adapter_indices)): + if adapter_indices[i] != adapter_indices[i - 1]: + segments.append(i) + segment_indices.append(adapter_indices[i - 1]) + start_index = i + + # Handle the last segment + if start_index < len(adapter_indices): + segments.append(len(adapter_indices)) + segment_indices.append(adapter_indices[-1]) + + return segments, segment_indices + + +class SegmentConcatBuilder: + def __init__(self): + self.adapter_segment_indices = [] + self.adapter_segment_tensors = [] + + def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]): + # Update adapter segments + if self.adapter_segment_tensors: + # Because we have already processed at least one batch, remove the 0 start index + # from this batch denoting the beginning of the segment, then offset all segment + # positions by the value of the last segment in the previous batch to account for + # the concatenation. + adapter_segments = ( + adapter_segments[1:] + self.adapter_segment_tensors[-1][-1] + ) + + if ( + self.adapter_segment_indices + and self.adapter_segment_indices[-1] == segment_indices[0] + ): + # If the last segment in the previous batch is the same as the first segment in this batch, + # then we merge them together into a single segment. In effect, this means removing it from + # the segment indices of this batch, and extending the segment span by removing the segment + # end index from the previous batch. + segment_indices = segment_indices[1:] + self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1] + + self.adapter_segment_indices.extend(segment_indices) + self.adapter_segment_tensors.append(adapter_segments) + + def build(self) -> Tuple[torch.Tensor, List[int]]: + return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices diff --git a/server/text_generation_server/utils/sgmv.py b/server/text_generation_server/utils/sgmv.py new file mode 100644 index 00000000..e0aec25f --- /dev/null +++ b/server/text_generation_server/utils/sgmv.py @@ -0,0 +1,248 @@ +# Origin: https://github.com/predibase/lorax +# Path: lorax/server/lorax_server/utils/sgmv.py +# License: Apache License Version 2.0, January 2004 + +import os +import warnings +from functools import lru_cache +from typing import List, Tuple + +import torch +import torch.nn.functional as F + +try: + import punica_kernels as _kernels + + HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", "")) +except ImportError: + warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.") + _kernels = None + HAS_SGMV = False + + +MIN_SGMV_RANK = 8 +MIN_RANK_CUSTOM = 16 +MAX_RANK_CUSTOM = 128 +SGMV_BLOCK_SIZE = 16 +BGMV_MAX_RANK = 64 + + +def has_sgmv() -> bool: + return HAS_SGMV + + +def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor: + """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size.""" + if not has_sgmv(): + return t + + # tensor parallelism will result in effective rank being divided by world_size, + # so we need to scale the min rank to offset that effect + min_rank = MIN_SGMV_RANK * world_size + + # if we're at or below the min rank, pad up to the min rank + # otherwise, pad to the nearest multiple of the block size + current_rank = t.size(dim) + target_rank = ( + min_rank + if current_rank <= min_rank + else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE + ) + if current_rank == target_rank: + return t + + pad_size = target_rank - current_rank + + # see complicatd pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + pad = [0, 0] * t.dim() + pad[(t.dim() - dim - 1) * 2 + 1] = pad_size + pad = tuple(pad) + + return F.pad(t, pad, mode="constant", value=0.0) + + +def use_cutlass_shrink(lora_rank: int) -> bool: + return lora_rank < MIN_RANK_CUSTOM + + +def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor: + if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM: + return t.transpose(0, 1) + return t + + +# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py +def add_lora_sgmv_cutlass( + y: torch.Tensor, + x: torch.Tensor, + wa_ptr: torch.Tensor, + wb_ptr: torch.Tensor, + s_start: torch.Tensor, + s_end: torch.Tensor, + layer_idx: int, + lora_rank: int, +): + """ + Semantics: + y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i]) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ + Weight matrix shape: `[num_layers, R, H1]`. + wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\ + Weight matrix shape: `[num_layers, R, H2]`. + s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices. + s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices. + layer_idx: Layer index of the weight matrices. + """ + if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM: + # Custom SGMV shrink only supports rank 16, 32, 64, 128 + _add_lora_sgmv_cutlass_legacy( + y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank + ) + return + + tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device) + tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) + tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device) + v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) + _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx) + _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx) + + +def _add_lora_sgmv_cutlass_legacy( + y: torch.Tensor, + x: torch.Tensor, + wa_ptr: torch.Tensor, + wb_ptr: torch.Tensor, + s_start: torch.IntTensor, + s_end: torch.IntTensor, + layer_idx: int, + lora_rank: int, +): + tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0)) + tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device) + v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) + _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) + _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) + + +@lru_cache(maxsize=1) +def get_tmp_tensor(device: torch.device) -> torch.Tensor: + return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device) + + +@lru_cache(maxsize=32) +def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor: + tmp_size = _kernels.sgmv_cutlass_tmp_size(size) + return torch.empty((tmp_size,), dtype=torch.uint8, device=device) + + +def get_tmp_tensor_for_size_no_kernels(size: int, device: torch.device) -> torch.Tensor: + return torch.empty((size,), dtype=torch.uint8, device=device) + + +def get_tmp_expand_size(size: int) -> int: + return _kernels.sgmv_cutlass_tmp_size(size) + + +def get_tmp_tensors( + nsegments: int, lora_rank: int, device: torch.device +) -> Tuple[torch.Tensor, torch.Tensor]: + if use_cutlass_shrink(lora_rank) and has_sgmv(): + tmp = get_tmp_tensor_for_size(nsegments, device) + return tmp, tmp + else: + tmp_shrink = get_tmp_tensor(device) + tmp_expand = get_tmp_tensor_for_size_no_kernels(nsegments, device) + return tmp_shrink, tmp_expand + + +def lora_a_sgmv_cutlass( + x: torch.Tensor, + tmp: torch.Tensor, + wa_ptr: torch.Tensor, + s_start: torch.IntTensor, + s_end: torch.IntTensor, + layer_idx: int, + lora_rank: int, +) -> torch.Tensor: + v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device) + if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM: + _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) + else: + _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx) + return v + + +def lora_b_sgmv_cutlass( + y: torch.Tensor, + v: torch.Tensor, + tmp: torch.Tensor, + wb_ptr: torch.Tensor, + s_start: torch.IntTensor, + s_end: torch.IntTensor, + layer_idx: int, +): + _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx) + + +""" +Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + +Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + v: Shape: `[B, R]`. Temporary vector. + x: Shape: `[B, H1]`. Input vectors. + wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices. + wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. +""" + + +def add_lora_a_bgmv( + v: torch.Tensor, + x: torch.Tensor, + wa_T_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, +): + _kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0) + + +def add_lora_b_bgmv( + y: torch.Tensor, + v: torch.Tensor, + wb_T_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, +): + _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0) + + +def segmented_matmul( + y: torch.Tensor, + x: torch.Tensor, + w: List[torch.Tensor], + b: List[torch.Tensor], + s_start: torch.IntTensor, + s_end: torch.IntTensor, +): + for i in range(len(w)): + if s_end[i] - s_start[i] <= 0: + continue + + xi = x[s_start[i] : s_end[i]] + wi = w[i] + bi = b[i] + y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi)