diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index a7351a33..be77b42c 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -58,6 +58,8 @@
   - local: conceptual/speculation
     title: Speculation (Medusa, ngram)
   - local: conceptual/guidance
     title: How Guidance Works (via outlines)
+  - local: conceptual/multi_lora
+    title: LoRA (Low-Rank Adaptation)
   title: Conceptual Guides
diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md
index 9246093e..a77f25a5 100644
--- a/docs/source/basic_tutorials/launcher.md
+++ b/docs/source/basic_tutorials/launcher.md
@@ -409,6 +409,14 @@ Options:
       [env: MAX_CLIENT_BATCH_SIZE=]
       [default: 4]
 
+```
+## LORA_ADAPTERS
+```shell
+      --lora-adapters
+          Lora Adapters: a list of adapter ids, e.g. `repo/adapter1,repo/adapter2`, to load during startup that will be available to callers via the `adapter_id` field in a request
+
+          [env: LORA_ADAPTERS=]
+
 ```
 ## HELP
 ```shell
diff --git a/docs/source/conceptual/multi_lora.md b/docs/source/conceptual/multi_lora.md
new file mode 100644
index 00000000..060fc3fe
--- /dev/null
+++ b/docs/source/conceptual/multi_lora.md
@@ -0,0 +1,84 @@
+# LoRA (Low-Rank Adaptation)
+
+## What is LoRA?
+
+LoRA is a technique that allows for efficient fine-tuning of a model while updating only a small portion of the model's weights. This is useful when you have a large model that has been pre-trained on a large dataset, but you want to fine-tune it on a smaller dataset or for a specific task.
+
+LoRA works by adding a small number of additional weights to the model, which are used to adapt the model to the new dataset or task. These additional weights are learned during the fine-tuning process, while the rest of the model's weights are kept fixed.
+
+## How is it used?
+
+LoRA can be used in many ways, and the community is always finding new ways to use it. Technically, LoRA fine-tunes a large language model on a small dataset, but these use cases span a wide range of applications, such as:
+
+- fine-tuning a language model on a small dataset
+- fine-tuning a language model on a domain-specific dataset
+- fine-tuning a language model on a dataset with limited labels
+
+## Optimizing Inference with LoRA
+
+LoRA adapters can be used during inference by multiplying the adapter weights with the model weights at each specified layer. This process can be computationally expensive, but thanks to the awesome work by [punica-ai](https://github.com/punica-ai/punica) and the [lorax](https://github.com/predibase/lorax) team, optimized kernels and frameworks have been developed to make this process more efficient. TGI leverages these optimizations in order to provide fast and efficient inference with multiple LoRA models.
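+
+Conceptually, applying a single adapter just adds a low-rank update to the output of each adapted layer. The snippet below is a minimal illustrative sketch in plain NumPy (hypothetical layer shapes, not TGI's actual kernels):
+
+```python
+import numpy as np
+
+# One linear layer: a frozen base weight plus a rank-r LoRA adapter.
+d_out, d_in, rank = 768, 768, 16
+W = np.random.randn(d_out, d_in)   # pre-trained weight (kept frozen)
+A = np.random.randn(rank, d_in)    # LoRA down-projection
+B = np.zeros((d_out, rank))        # LoRA up-projection (zero-initialized)
+scaling = 1.0                      # typically lora_alpha / rank
+
+def forward(x):
+    # Base layer output plus the low-rank adapter update.
+    return x @ W.T + scaling * ((x @ A.T) @ B.T)
+
+y = forward(np.random.randn(2, d_in))  # -> shape (2, 768)
+```
+
+Optimized multi-LoRA kernels batch this low-rank update across many adapters and requests at once, which is what makes serving several adapters from a single deployment practical.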
+
+## Serving multiple LoRA adapters with TGI
+
+Once a LoRA model has been trained, it can be used to generate text or perform other tasks just like a regular language model. However, because the model has been fine-tuned on a specific dataset, it may perform better on that dataset than a model that has not been fine-tuned.
+
+In practice, it's often useful to have multiple LoRA models, each fine-tuned on a different dataset or for a different task. This allows you to use the model that is best suited for a particular task or dataset.
+
+Text Generation Inference (TGI) now supports loading multiple LoRA models at startup that can be used in generation requests. This feature is available starting from version `~2.0.6` and is compatible with LoRA models trained using the `peft` library.
+
+### Specifying LoRA models
+
+To use LoRA in TGI, when starting the server, you can specify the list of LoRA models to load using the `LORA_ADAPTERS` environment variable. For example:
+
+```bash
+LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
+```
+
+In the server logs, you will see the following message:
+
+```txt
+Loading adapter weights into model: predibase/customer_support
+Loading adapter weights into model: predibase/dbpedia
+```
+
+## Generate text
+
+You can then use these models in generation requests by specifying the `adapter_id` parameter in the request payload. For example:
+
+```bash
+curl 127.0.0.1:3000/generate \
+    -X POST \
+    -H 'Content-Type: application/json' \
+    -d '{
+  "inputs": "Hello who are you?",
+  "parameters": {
+    "max_new_tokens": 40,
+    "adapter_id": "predibase/customer_support"
+  }
+}'
+```
+
+> **Note:** The LoRA feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally, documentation and an improved client library will be published soon.
+
+An updated tutorial with detailed examples will be published soon. Stay tuned!
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index e4d5bb85..45868096 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -449,6 +449,11 @@ struct Args {
     /// Control the maximum number of inputs that a client can send in a single request
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
+
+    /// Lora Adapters: a list of adapter ids, e.g. `repo/adapter1,repo/adapter2`, to load during
+    /// startup that will be available to callers via the `adapter_id` field in a request.
+    #[clap(long, env)]
+    lora_adapters: Option<String>,
 }
 
 #[derive(Debug)]
@@ -482,6 +487,7 @@ fn shard_manager(
     max_total_tokens: usize,
     max_batch_size: Option<usize>,
     max_input_tokens: usize,
+    lora_adapters: Option<String>,
     otlp_endpoint: Option<String>,
     log_level: LevelFilter,
     status_sender: mpsc::Sender<ShardStatus>,
@@ -612,6 +618,11 @@ fn shard_manager(
         envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
     }
 
+    // Lora Adapters: forward the adapter list to the shard via its environment
+    if let Some(lora_adapters) = lora_adapters {
+        envs.push(("LORA_ADAPTERS".into(), lora_adapters.into()));
+    }
+
     // If huggingface_hub_cache is some, pass it to the shard
     // Useful when running inside a docker container
     if let Some(huggingface_hub_cache) = huggingface_hub_cache {
@@ -1048,6 +1059,7 @@ fn spawn_shards(
     let rope_scaling = args.rope_scaling;
     let rope_factor = args.rope_factor;
     let max_batch_size = args.max_batch_size;
+    let lora_adapters = args.lora_adapters.clone();
    thread::spawn(move || {
         shard_manager(
             model_id,
@@ -1073,6 +1085,7 @@ fn spawn_shards(
             max_total_tokens,
             max_batch_size,
             max_input_tokens,
+            lora_adapters,
             otlp_endpoint,
             max_log_level,
             status_sender,
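
For completeness, the documented request can also be sent from Python. The following is an illustrative sketch using the `requests` library; the host, port, and adapter id are the example values from `multi_lora.md` above, and it assumes a server started with the new `--lora-adapters` flag (or the `LORA_ADAPTERS` environment variable) returning the usual `/generate` response shape with a `generated_text` field:

```python
import requests

# Assumes a TGI server launched with the adapters from the docs above, e.g.:
#   LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia text-generation-launcher ...
response = requests.post(
    "http://127.0.0.1:3000/generate",
    json={
        "inputs": "Hello who are you?",
        "parameters": {
            "max_new_tokens": 40,
            # Must match one of the adapter ids loaded at startup
            "adapter_id": "predibase/customer_support",
        },
    },
    timeout=60,
)
response.raise_for_status()
print(response.json()["generated_text"])
```

Omitting `adapter_id` should fall back to the base model.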