From 73d84c6ee5648def2a6cd77b810edeca9933c1ab Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 15 May 2023 10:35:20 +0200
Subject: [PATCH] Hotfixes for santacoder/bigcode. (#294)

# What does this PR do?

Hotfixes:
- Uses `model_type`=`gpt_bigcode` for more general usage.
- Hotfixes the linked `lm_head` vs `wte` embedding (sharded safetensors files correctly omit the duplicated key, whereas PyTorch checkpoints copy the tensor).

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

---------

Co-authored-by: Ubuntu
Co-authored-by: OlivierDehaene
---
 server/text_generation_server/models/__init__.py | 13 ++++++++++++-
 .../models/flash_santacoder.py                   |  3 +++
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index e02be3de..ec990fde 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -99,7 +99,7 @@ def get_model(
         else:
             return Galactica(model_id, revision, quantize=quantize)
 
-    if "bigcode" in model_id:
+    if model_id.startswith("bigcode/"):
         if sharded:
             if not FLASH_ATTENTION:
                 raise NotImplementedError(
@@ -113,6 +113,17 @@ def get_model(
     config = AutoConfig.from_pretrained(model_id, revision=revision)
     model_type = config.model_type
 
+    if model_type == "gpt_bigcode":
+        if sharded:
+            if not FLASH_ATTENTION:
+                raise NotImplementedError(
+                    FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
+                )
+            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
+        else:
+            santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
+            return santacoder_cls(model_id, revision, quantize=quantize)
+
     if model_type == "bloom":
         if sharded:
             return BLOOMSharded(model_id, revision, quantize=quantize)
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index c463ee98..afe4eba5 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -376,6 +376,9 @@ class FlashSantacoderSharded(FlashSantacoder):
                 else:
                     module._buffers[param_name] = tensor
 
+
+        model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
+
         uninitialized_parameters = []
         for n, p in model.named_parameters():
             if p.data.device == torch.device("meta"):
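
For context on the first hotfix: routing now keys off `config.model_type` instead of substring-matching the model id, so any checkpoint that declares `gpt_bigcode` in its config takes the Santacoder path regardless of its hub id. A minimal sketch of that check (the hub id `bigcode/gpt_bigcode-santacoder` is assumed here for illustration, not taken from the patch):

```python
# Sketch: dispatch on config.model_type rather than the model id string.
# The hub id below is assumed for illustration; any repo whose config.json
# declares "model_type": "gpt_bigcode" would take the same branch.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("bigcode/gpt_bigcode-santacoder")
if config.model_type == "gpt_bigcode":
    print("routes to the (Flash)Santacoder implementations")
```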
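And for the second hotfix: SantaCoder ties its output projection to the input embedding, and a sharded safetensors checkpoint correctly omits the duplicated `lm_head.weight` key, so the head must be re-linked to the embedding after loading. A toy reproduction of that re-tying step (`TinyLM` and its dimensions are made up; the real fix runs on `model.transformer.wte` inside `FlashSantacoderSharded`):

```python
# Toy reproduction of the re-tying step, assuming a GPT-2-style layout.
# TinyLM is illustrative only; it is not the server's model class.
import torch

class TinyLM(torch.nn.Module):
    def __init__(self, vocab_size: int = 8, hidden: int = 4):
        super().__init__()
        self.wte = torch.nn.Embedding(vocab_size, hidden)               # input embedding
        self.lm_head = torch.nn.Linear(hidden, vocab_size, bias=False)  # tied output head

model = TinyLM()

# A sharded safetensors file stores the embedding but, correctly, not the
# aliased lm_head key, so only wte.weight comes back from the checkpoint:
shard = {"wte.weight": torch.randn(8, 4)}
model.load_state_dict(shard, strict=False)  # lm_head.weight is left untouched

# The hotfix: re-link the head to the embedding after loading, mirroring
#   model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
model.lm_head.weight = torch.nn.Parameter(model.wte.weight)
assert model.lm_head.weight.data_ptr() == model.wte.weight.data_ptr()
```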