diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 301acb6b..b741a84c 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -150,6 +150,17 @@ def download_weights( if not extension == ".safetensors" or not auto_convert: raise e + else: + # Try to load as a local PEFT model + try: + utils.download_and_unload_peft( + model_id, revision, trust_remote_code=trust_remote_code + ) + utils.weight_files(model_id, revision, extension) + return + except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): + pass + # Try to see if there are local pytorch weights try: # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE