diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 877acb00..f6a69031 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -238,15 +238,6 @@ class BLOOMSharded(BLOOM): if name == "word_embeddings.weight": model.lm_head._parameters["weight"] = tensor - uninitialized_parameters = [] - for n, p in model.named_parameters(): - if p.data.device == torch.device("meta"): - uninitialized_parameters.append(n) - if uninitialized_parameters: - raise RuntimeError( - f"found uninitialized parameters in model: {uninitialized_parameters}" - ) - def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 5f47cf66..0b63f904 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -139,15 +139,6 @@ class FlashLlama(FlashCausalLM): del value - uninitialized_parameters = [] - for n, p in model.named_parameters(): - if p.data.device == torch.device("meta"): - uninitialized_parameters.append(n) - if uninitialized_parameters: - raise RuntimeError( - f"found uninitialized parameters in model: {uninitialized_parameters}" - ) - torch.cuda.empty_cache() model.post_load_weights(quantize) @@ -315,14 +306,5 @@ class FlashLlamaSharded(FlashLlama): else: module._buffers[param_name] = tensor - uninitialized_parameters = [] - for n, p in model.named_parameters(): - if p.data.device == torch.device("meta"): - uninitialized_parameters.append(n) - if uninitialized_parameters: - raise RuntimeError( - f"found uninitialized parameters in model: {uninitialized_parameters}" - ) - torch.cuda.empty_cache() model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index b3e1876f..168c9195 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -152,13 +152,4 @@ class FlashNeoXSharded(FlashNeoX): else: module._buffers[param_name] = tensor - uninitialized_parameters = [] - for n, p in model.named_parameters(): - if p.data.device == torch.device("meta"): - uninitialized_parameters.append(n) - if uninitialized_parameters: - raise RuntimeError( - f"found uninitialized parameters in model: {uninitialized_parameters}" - ) - model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index afe4eba5..51a8998b 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -376,17 +376,6 @@ class FlashSantacoderSharded(FlashSantacoder): else: module._buffers[param_name] = tensor - model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight) - - uninitialized_parameters = [] - for n, p in model.named_parameters(): - if p.data.device == torch.device("meta"): - uninitialized_parameters.append(n) - if uninitialized_parameters: - raise RuntimeError( - f"found uninitialized parameters in model: {uninitialized_parameters}" - ) - torch.cuda.empty_cache() model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 4f94b348..c6dd4c33 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -365,15 +365,6 @@ class GalacticaSharded(Galactica): if name == "model.decoder.embed_tokens.weight": model.lm_head._parameters["weight"] = tensor - uninitialized_parameters = [] - for n, p in model.named_parameters(): - if p.data.device == torch.device("meta"): - uninitialized_parameters.append(n) - if uninitialized_parameters: - raise RuntimeError( - f"found uninitialized parameters in model: {uninitialized_parameters}" - ) - def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index 2d42e0b0..215bb2b6 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -215,15 +215,6 @@ class GPTNeoxSharded(CausalLM): else: module._buffers[param_name] = tensor - uninitialized_parameters = [] - for n, p in model.named_parameters(): - if p.data.device == torch.device("meta"): - uninitialized_parameters.append(n) - if uninitialized_parameters: - raise RuntimeError( - f"found uninitialized parameters in model: {uninitialized_parameters}" - ) - def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 4c85c952..03f14013 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -32,6 +32,7 @@ class Model(ABC): self.decode_buffer = decode_buffer self.rank = rank self.world_size = world_size + self.check_initialized() @property def info(self) -> InfoResponse: @@ -99,3 +100,13 @@ class Model(ABC): return token_text, None, None else: return "", offset, token_offset + + def check_initialized(self): + uninitialized_parameters = [] + for n, p in self.model.named_parameters(): + if p.data.device == torch.device("meta"): + uninitialized_parameters.append(n) + if uninitialized_parameters: + raise RuntimeError( + f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}" + ) diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 44f15df3..8d856b10 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -212,15 +212,6 @@ class OPTSharded(OPT): if name == "model.decoder.embed_tokens.weight": model.lm_head._parameters["weight"] = tensor - uninitialized_parameters = [] - for n, p in model.named_parameters(): - if p.data.device == torch.device("meta"): - uninitialized_parameters.append(n) - if uninitialized_parameters: - raise RuntimeError( - f"found uninitialized parameters in model: {uninitialized_parameters}" - ) - def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 381617b7..b5e7710d 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -222,15 +222,6 @@ class T5Sharded(Seq2SeqLM): else: module._buffers[param_name] = tensor - uninitialized_parameters = [] - for n, p in model.named_parameters(): - if p.data.device == torch.device("meta"): - uninitialized_parameters.append(n) - if uninitialized_parameters: - raise RuntimeError( - f"found uninitialized parameters in model: {uninitialized_parameters}" - ) - def forward( self, input_ids,