From 92bb56b0c1038a35f73a6c96c506f6d1c3d7b043 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 31 Jul 2023 10:32:52 +0200
Subject: [PATCH] Local gptq support. (#738)

# What does this PR do?

Redoes #719

Fixes # (issue)

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
---
 .github/workflows/tests.yaml                   | 2 +-
 Dockerfile                                     | 2 +-
 server/text_generation_server/utils/weights.py | 7 ++++++-
 3 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 7e5ba52c..311ee6b9 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -33,7 +33,7 @@ jobs:
       - name: Install Rust
         uses: actions-rs/toolchain@v1
         with:
-          toolchain: 1.65.0
+          toolchain: 1.71.0
           override: true
           components: rustfmt, clippy
       - name: Install Protoc
diff --git a/Dockerfile b/Dockerfile
index 34109d02..679fb48a 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.70 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
 WORKDIR /usr/src
 
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 0330402d..dcab6296 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -1,3 +1,4 @@
+import os
 from pathlib import Path
 from typing import List, Dict, Optional, Tuple
 from safetensors import safe_open, SafetensorError
@@ -221,8 +222,12 @@ class Weights:
         return bits, groupsize
 
     def _set_gptq_params(self, model_id):
+        filename = "quantize_config.json"
         try:
-            filename = hf_hub_download(model_id, filename="quantize_config.json")
+            if os.path.exists(os.path.join(model_id, filename)):
+                filename = os.path.join(model_id, filename)
+            else:
+                filename = hf_hub_download(model_id, filename=filename)
             with open(filename, "r") as f:
                 data = json.load(f)
             self.gptq_bits = data["bits"]