From 91e674bb85760a19afb509cc0010d46b090183fd Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 15 May 2023 11:32:25 +0200
Subject: [PATCH] Lifting check_initialized. (#325)

# What does this PR do?

Lifts the duplicated uninitialized-parameter check out of each model's weight-loading code and into the shared `Model` base class as a `check_initialized` method, which now runs once at the end of `Model.__init__`.

Fixes # (issue)
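The lifted check is small enough to show in isolation. Below is a minimal, self-contained sketch of the same meta-device test, assuming a PyTorch version with meta-device support; `DemoModel` and the free-standing `check_initialized` helper are illustrative stand-ins for the method this PR adds to `Model`:

```python
import torch


class DemoModel(torch.nn.Module):
    """Hypothetical module whose parameters are left on the meta device,
    mimicking a weight that was never materialized from a checkpoint."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4, device="meta")


def check_initialized(model: torch.nn.Module) -> None:
    # Collect every parameter still on the meta device: such tensors carry
    # shape/dtype metadata but no storage, so using them would fail later.
    uninitialized_parameters = []
    for n, p in model.named_parameters():
        if p.data.device == torch.device("meta"):
            uninitialized_parameters.append(n)
    if uninitialized_parameters:
        raise RuntimeError(
            f"found uninitialized parameters in model: {uninitialized_parameters}"
        )


check_initialized(DemoModel())
# RuntimeError: found uninitialized parameters in model:
# ['linear.weight', 'linear.bias']
```

Running the check once in `Model.__init__` means every subclass gets it for free, and a partially loaded model fails fast with the offending parameter names instead of crashing later in `forward`.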
{uninitialized_parameters}" - ) - torch.cuda.empty_cache() model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index b3e1876f..168c9195 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -152,13 +152,4 @@ class FlashNeoXSharded(FlashNeoX): else: module._buffers[param_name] = tensor - uninitialized_parameters = [] - for n, p in model.named_parameters(): - if p.data.device == torch.device("meta"): - uninitialized_parameters.append(n) - if uninitialized_parameters: - raise RuntimeError( - f"found uninitialized parameters in model: {uninitialized_parameters}" - ) - model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index afe4eba5..51a8998b 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -376,17 +376,6 @@ class FlashSantacoderSharded(FlashSantacoder): else: module._buffers[param_name] = tensor - model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight) - - uninitialized_parameters = [] - for n, p in model.named_parameters(): - if p.data.device == torch.device("meta"): - uninitialized_parameters.append(n) - if uninitialized_parameters: - raise RuntimeError( - f"found uninitialized parameters in model: {uninitialized_parameters}" - ) - torch.cuda.empty_cache() model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 4f94b348..c6dd4c33 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -365,15 +365,6 @@ class GalacticaSharded(Galactica): if name == "model.decoder.embed_tokens.weight": model.lm_head._parameters["weight"] = tensor - uninitialized_parameters = [] - for n, p in model.named_parameters(): - if p.data.device == torch.device("meta"): - uninitialized_parameters.append(n) - if uninitialized_parameters: - raise RuntimeError( - f"found uninitialized parameters in model: {uninitialized_parameters}" - ) - def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index 2d42e0b0..215bb2b6 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -215,15 +215,6 @@ class GPTNeoxSharded(CausalLM): else: module._buffers[param_name] = tensor - uninitialized_parameters = [] - for n, p in model.named_parameters(): - if p.data.device == torch.device("meta"): - uninitialized_parameters.append(n) - if uninitialized_parameters: - raise RuntimeError( - f"found uninitialized parameters in model: {uninitialized_parameters}" - ) - def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 4c85c952..03f14013 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -32,6 +32,7 @@ class Model(ABC): self.decode_buffer = decode_buffer self.rank = rank self.world_size = world_size + self.check_initialized() @property def info(self) -> InfoResponse: @@ -99,3 
+100,13 @@ class Model(ABC): return token_text, None, None else: return "", offset, token_offset + + def check_initialized(self): + uninitialized_parameters = [] + for n, p in self.model.named_parameters(): + if p.data.device == torch.device("meta"): + uninitialized_parameters.append(n) + if uninitialized_parameters: + raise RuntimeError( + f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}" + ) diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 44f15df3..8d856b10 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -212,15 +212,6 @@ class OPTSharded(OPT): if name == "model.decoder.embed_tokens.weight": model.lm_head._parameters["weight"] = tensor - uninitialized_parameters = [] - for n, p in model.named_parameters(): - if p.data.device == torch.device("meta"): - uninitialized_parameters.append(n) - if uninitialized_parameters: - raise RuntimeError( - f"found uninitialized parameters in model: {uninitialized_parameters}" - ) - def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 381617b7..b5e7710d 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -222,15 +222,6 @@ class T5Sharded(Seq2SeqLM): else: module._buffers[param_name] = tensor - uninitialized_parameters = [] - for n, p in model.named_parameters(): - if p.data.device == torch.device("meta"): - uninitialized_parameters.append(n) - if uninitialized_parameters: - raise RuntimeError( - f"found uninitialized parameters in model: {uninitialized_parameters}" - ) - def forward( self, input_ids,