# Train Medusa

This tutorial will show you how to train a Medusa model on a dataset of your choice. Please check out the [speculation documentation](../conceptual/speculation) for more information on how Medusa works and on speculation in general.

## What are the benefits of training a Medusa model?

Training Medusa heads can greatly improve the speed of generation. Medusa adds extra "heads" to LLMs to predict multiple future tokens simultaneously. When augmenting a model with Medusa, the original model stays untouched, and only the new heads are fine-tuned during training.

One of the most important things is to have a good dataset (with similar data to what will be used in production) because Medusa has a much higher hit-rate when the generation is in-domain.

If you train Medusa on a dataset that is very different from the one you will use in production, the heads will not be able to predict the future tokens accurately, and consequently the speedup will be minimal or non-existent.

## Self-distillation (Generating data for training)

There are many methods for preparing data for training, but one of the easiest and most effective ways is to "self-distill" the data. This means that you use the same model to generate the data that you will use to train the Medusa heads.

Essentially, you prompt the model with inputs similar to what you will use in production and the model generates the output.

We'll use this output to help train the Medusa heads to predict the `n+1`, `n+2`, `n+3`, etc. tokens in the sequence.

## Training

The original implementation of Medusa is available at [https://github.com/FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa), and we'll follow a process very similar to the one described in the original repository.
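As a rough illustration of the `n+1`, `n+2`, `n+3` targets mentioned in the self-distillation section, the sketch below pairs each position in a token sequence with the token a fixed number of steps ahead. The helper `shifted_targets` and the token ids are hypothetical and only meant to show the idea; this is not the training code from the Medusa repository.

```python
def shifted_targets(token_ids, offset):
    """Pair each position's token with the token `offset` steps ahead of it.

    Hypothetical helper for illustration only; not part of the Medusa codebase.
    """
    if offset < 1:
        raise ValueError("offset must be >= 1")
    # Positions near the end have no token `offset` steps ahead, so they are dropped.
    return list(zip(token_ids[:-offset], token_ids[offset:]))


tokens = [10, 11, 12, 13, 14]  # made-up token ids standing in for a generated answer

for k in (1, 2, 3):
    # Each pair is (token at position n, token at position n + k).
    print(f"n+{k}:", shifted_targets(tokens, k))
```

Larger offsets leave fewer usable pairs per sequence, and they are also harder to predict, which is one reason in-domain training data matters so much for the hit-rate.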
### Getting Started

There are two methods for training the model:

- `torchrun`, PyTorch's distributed launcher that supersedes `torch.distributed.launch`
- a forked version of `axolotl` that supports Medusa

In this tutorial we'll use `torchrun` because it is the most straightforward option, but similar steps can be followed with `axolotl` if you prefer.

### Training with `torchrun`

```bash
mkdir medusa-training
cd medusa-training

pyenv install 3.10
pyenv local 3.10

uv venv -p 3.10
source .venv/bin/activate
```

Now let's clone the original `Medusa` repository and install the library.

```bash
git clone https://github.com/FasterDecoding/Medusa.git
cd Medusa
pip install -e .
```

Next we'll need some data to train on. We can use the `ShareGPT_Vicuna_unfiltered` dataset that is available on the Hugging Face Hub.

```bash
apt install git-lfs
git lfs install
git clone https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered
```

Currently our directory structure looks like this:

```bash
.
├── assets
├── CITATION.cff
├── create_data.py
├── data_generation
├── deepspeed.json
├── last_run_prepared
├── LICENSE
├── llm_judge
├── medusa
├── medusa_llm.egg-info
├── mistral.json
├── notebooks
├── pyproject.toml
├── README.md
├── ROADMAP.md
├── scripts
├── ShareGPT_Vicuna_unfiltered
│   ├── README.md
│   ├── ShareGPT_2023.05.04v0_Wasteland_Edition.json
│   └── ShareGPT_V4.3_unfiltered_cleaned_split.json
├── simple_gradio_interface.py
├── tiny-llama.json
└── vicuna_7b_qlora_stage1
```

## Start Training

Now let's generate the data and start training the model. This process will take a while since we are generating data from the model.

First make sure you have an instance of TGI running with the model you want to use for self-distillation.
```bash
model=HuggingFaceH4/zephyr-7b-beta
volume=/home/ubuntu/.cache/huggingface/hub/

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model
```

Now we can generate the data using the `create_data.py` script.

```bash
python create_data.py \
    --input-filename ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json \
    --output-filename zephyr_self_distill.json
```

At this point our terminal should look like this: