From e12c34bd25c8297c1df33493cbcedef8ffa1306f Mon Sep 17 00:00:00 2001
From: Traun Leyden
Date: Thu, 23 Nov 2023 12:56:17 +0100
Subject: [PATCH] Load PEFT weights from local directory (#1260)

# What does this PR do?

Enables PEFT weights to be loaded from a local directory, as opposed to
a hf hub repository. It is a continuation of the work in PR
https://github.com/huggingface/text-generation-inference/pull/762

Fixes #1259

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section? **Yes but I don't know how to run the tests for
this repo, and it doesn't look like this code is covered anyway**
- [x] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link to it if
that's the case. **Yes, @Narsil asked for a PR in [this
comment](https://github.com/huggingface/text-generation-inference/pull/762#issuecomment-1728089505)**
- [x] Did you make sure to update the documentation with your changes?
Here are the [documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and [here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
**I didn't see any documentation added to the [original
PR](https://github.com/huggingface/text-generation-inference/pull/762),
and am not sure where this belongs. Let me know and I can add some**
- [x] Did you write any new necessary tests? **I didn't see any existing
test coverage for this python module**

## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag members/contributors who may be interested in
your PR.

@Narsil

---------

Co-authored-by: Nicolas Patry
---
 server/text_generation_server/cli.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 301acb6b..b741a84c 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -150,6 +150,17 @@ def download_weights(
             if not extension == ".safetensors" or not auto_convert:
                 raise e
 
+    else:
+        # Try to load as a local PEFT model
+        try:
+            utils.download_and_unload_peft(
+                model_id, revision, trust_remote_code=trust_remote_code
+            )
+            utils.weight_files(model_id, revision, extension)
+            return
+        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+            pass
+
     # Try to see if there are local pytorch weights
     try:
         # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE