From 214c06f51005b5be2a06963bb3ad1281abc2bc18 Mon Sep 17 00:00:00 2001
From: cdawg
Date: Thu, 20 Jul 2023 13:53:08 +0200
Subject: [PATCH] Add trust_remote_code to quantize script (#647)

# What does this PR do?

Fixes a bug that appeared with PR #587, which fixed issue #552; see the discussion in #552. With PR #587, the `trust_remote_code` variable is present in the function signature but is no longer passed on to `AutoModelForCausalLM`. This prevents models such as Falcon, which require `trust_remote_code`, from being quantized. This PR fixes the issue.

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [X] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [X] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@Narsil
---
 server/text_generation_server/utils/gptq/quantize.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/utils/gptq/quantize.py b/server/text_generation_server/utils/gptq/quantize.py
index d182456f..bee1e446 100644
--- a/server/text_generation_server/utils/gptq/quantize.py
+++ b/server/text_generation_server/utils/gptq/quantize.py
@@ -864,7 +864,8 @@ def quantize(
     )
 
     with init_empty_weights():
-        model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
+        model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16,
+                                                 trust_remote_code=trust_remote_code)
     model = model.eval()
 
     print("LOADED model")
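
For context, here is a minimal standalone sketch (not part of the patch) of the call sequence the patched `quantize` function performs, showing why the flag must reach `from_config`. It assumes `torch`, `transformers`, and `accelerate` are installed; the Falcon checkpoint name is only illustrative.

```python
import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# Illustrative checkpoint: Falcon ships its modeling code inside the model
# repo, so resolving its config and model class needs trust_remote_code=True.
model_id = "tiiuae/falcon-7b"

config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

with init_empty_weights():
    # Before this patch, quantize() dropped the flag at this step, so
    # from_config could not resolve the custom architecture and failed.
    # Forwarding it lets the custom model class load on the meta device
    # without allocating real weights.
    model = AutoModelForCausalLM.from_config(
        config, torch_dtype=torch.float16, trust_remote_code=True
    )
model = model.eval()
```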