From 8a5f5649429f8df0ae5b86e485879c3d21d255f6 Mon Sep 17 00:00:00 2001
From: Vincent Brouwers
Date: Thu, 31 Aug 2023 21:15:14 +0200
Subject: [PATCH] Fix Falcon weight mapping for H2O.ai checkpoints (#953)

# What does this PR do?

During the safetensor conversion, duplicate weights are removed. However, which of the duplicates gets removed differs per checkpoint. In some, like `h2oai/h2ogpt-oig-oasst1-falcon-40b`, the weight `transformer.word_embeddings.weight` gets removed. In others, `lm_head.weight` gets removed. Long story long, we need to support both.

Originally, f018143 mapped `lm_head` to `word_embeddings`. Then ac736fd switched this around. This commit merges them and allows for both.

## Before submitting

- [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

@Narsil, you wrote both commits I referenced in this PR. I think you'll understand this change :)
---
 server/text_generation_server/models/flash_rw.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index 2fc7c53d..195b3883 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -54,7 +54,10 @@ class FlashRWSharded(FlashCausalLM):
             device,
             dtype,
             process_group=self.process_group,
-            aliases={"lm_head.weight": ["transformer.word_embeddings.weight"]},
+            aliases={
+                "lm_head.weight": ["transformer.word_embeddings.weight"],
+                "transformer.word_embeddings.weight": ["lm_head.weight"],
+            },
         )
         config.quantize = quantize
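
For context (not part of the patch itself): below is a minimal sketch of how a symmetric alias map like the one added above can resolve a weight lookup regardless of which of the two duplicate names survived the safetensor conversion. `SimpleWeights` and its `get_tensor` method are hypothetical stand-ins used only to illustrate the idea; they are not the server's actual `Weights` implementation.

```python
# Hypothetical illustration of alias-based weight lookup. `SimpleWeights`
# is a made-up stand-in, not the real text_generation_server Weights class.
from typing import Dict, List


class SimpleWeights:
    def __init__(self, tensors: Dict[str, object], aliases: Dict[str, List[str]]):
        self.tensors = tensors  # names actually present in the checkpoint
        self.aliases = aliases  # missing name -> candidate names to try instead

    def get_tensor(self, name: str) -> object:
        if name in self.tensors:
            return self.tensors[name]
        # Fall back to aliases when the requested name was deduplicated away.
        for alias in self.aliases.get(name, []):
            if alias in self.tensors:
                return self.tensors[alias]
        raise KeyError(f"weight {name} not found and no alias resolved")


aliases = {
    "lm_head.weight": ["transformer.word_embeddings.weight"],
    "transformer.word_embeddings.weight": ["lm_head.weight"],
}

# Checkpoint A kept only the embedding weight; checkpoint B kept only lm_head.
ckpt_a = SimpleWeights({"transformer.word_embeddings.weight": "W"}, aliases)
ckpt_b = SimpleWeights({"lm_head.weight": "W"}, aliases)

assert ckpt_a.get_tensor("lm_head.weight") == "W"
assert ckpt_b.get_tensor("transformer.word_embeddings.weight") == "W"
```

With a one-directional map, only one of the two assertions above would pass; making the alias map symmetric, as this patch does, covers both pruning outcomes.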