diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 403f46e7..99be6c7e 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -198,6 +198,35 @@ def download_weights(
             if not extension == ".safetensors" or not auto_convert:
                 raise e
 
+    elif (Path(model_id) / "medusa_lm_head.pt").exists():
+        # Try to load as a local Medusa model
+        try:
+            import json
+
+            medusa_head = Path(model_id) / "medusa_lm_head.pt"
+            if auto_convert:
+                medusa_sf = Path(model_id) / "medusa_lm_head.safetensors"
+                if not medusa_sf.exists():
+                    utils.convert_files([Path(medusa_head)], [medusa_sf], [])
+            medusa_config = Path(model_id) / "config.json"
+            with open(medusa_config, "r") as f:
+                config = json.load(f)
+
+            model_id = config["base_model_name_or_path"]
+            revision = "main"
+            try:
+                utils.weight_files(model_id, revision, extension)
+                logger.info(
+                    f"Files for parent {model_id} are already present on the host. "
+                    "Skipping download."
+                )
+                return
+            # Local files not found
+            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+                pass
+        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+            pass
+
     elif (Path(model_id) / "adapter_config.json").exists():
         # Try to load as a local PEFT model
         try:
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 8a3bccdd..7be61906 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -71,15 +71,26 @@ class FlashLlama(FlashCausalLM):
             from text_generation_server.utils.medusa import MedusaModel
             from huggingface_hub import hf_hub_download
             import json
-
-            medusa_config = hf_hub_download(
-                use_medusa, revision=revision, filename="config.json"
-            )
+            import os
+            from pathlib import Path
+
+            is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
+                "WEIGHTS_CACHE_OVERRIDE", None
+            ) is not None
+
+            if not is_local_model:
+                medusa_config = hf_hub_download(
+                    use_medusa, revision=revision, filename="config.json"
+                )
+                medusa_head = hf_hub_download(
+                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
+                )
+            else:
+                medusa_config = str(Path(use_medusa) / "config.json")
+                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
+
             with open(medusa_config, "r") as f:
                 config = json.load(f)
-            medusa_head = hf_hub_download(
-                use_medusa, revision=revision, filename="medusa_lm_head.pt"
-            )
             medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
             weights = Weights(
                 [medusa_sf], device, dtype, process_group=self.process_group
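
For reference, a minimal sketch of the resolution order the two hunks add, as one self-contained snippet. The directory path and base model name are hypothetical, and the snippet mirrors the server code rather than importing it; hf_hub_download is the real huggingface_hub helper used in the diff.

import json
import os
from pathlib import Path

from huggingface_hub import hf_hub_download

# Hypothetical local Medusa checkout: config.json and medusa_lm_head.pt
# are expected to sit side by side, as in the cli.py hunk.
use_medusa = "/data/medusa-vicuna-7b"
revision = "main"

# Same test as the flash_llama.py hunk: a local directory, or a set
# WEIGHTS_CACHE_OVERRIDE, bypasses the hub entirely.
is_local_model = (
    Path(use_medusa).exists() and Path(use_medusa).is_dir()
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None

if not is_local_model:
    medusa_config = hf_hub_download(use_medusa, revision=revision, filename="config.json")
    medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt")
else:
    medusa_config = str(Path(use_medusa) / "config.json")
    medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")

# download-weights resolves the parent model from the Medusa config
# (e.g. "lmsys/vicuna-7b-v1.3") and fetches its weights separately.
with open(medusa_config, "r") as f:
    base_model = json.load(f)["base_model_name_or_path"]

# Either way, the head actually loaded is the safetensors file produced
# at download time, so only the filename suffix is swapped.
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"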