From 880a76eed5f058043367d9643be8a498b286bde2 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Wed, 12 Apr 2023 17:18:08 +0200
Subject: [PATCH] feat(server): support sharded santacoder (#167)

---
 .../text_generation_server/models/__init__.py |  14 +-
 server/text_generation_server/models/bloom.py |   6 +-
 .../custom_modeling/flash_llama_modeling.py   |   2 +-
 .../custom_modeling/flash_neox_modeling.py    |   7 +-
 .../flash_santacoder_modeling.py              | 209 +++++++++++++++-
 .../models/flash_neox.py                      |  14 +-
 .../models/flash_santacoder.py                | 234 +++++++++++++++++-
 .../models/galactica.py                       |   6 +-
 .../text_generation_server/models/gpt_neox.py |   6 +-
 server/text_generation_server/models/opt.py   |   6 +-
 server/text_generation_server/models/t5.py    |   6 +-
 11 files changed, 462 insertions(+), 48 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index c04ae11..368060a 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -18,8 +18,11 @@ from text_generation_server.models.t5 import T5Sharded
 try:
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
-    from text_generation_server.models.flash_santacoder import FlashSantacoder
     from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded
+    from text_generation_server.models.flash_santacoder import (
+        FlashSantacoder,
+        FlashSantacoderSharded,
+    )

     FLASH_ATTENTION = torch.cuda.is_available()
 except ImportError:
@@ -49,6 +52,7 @@ if FLASH_ATTENTION:
     __all__.append(FlashNeoX)
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashSantacoder)
+    __all__.append(FlashSantacoderSharded)
     __all__.append(FlashLlama)
     __all__.append(FlashLlamaSharded)
@@ -78,9 +82,13 @@ def get_model(
         else:
             return Galactica(model_id, revision, quantize=quantize)

-    if "santacoder" in model_id:
+    if "bigcode" in model_id:
         if sharded:
-            raise NotImplementedError("sharded is not supported for Santacoder")
+            if not FLASH_ATTENTION:
+                raise NotImplementedError(
+                    FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
+                )
+            return FlashSantacoderSharded(model_id, revision=revision)
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
             return santacoder_cls(model_id, revision, quantize)
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index efcc9e0..731a985 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -93,10 +93,11 @@ class BLOOMSharded(BLOOM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer, device=device, decode_buffer=1
         )
@@ -108,6 +109,7 @@ class BLOOMSharded(BLOOM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -157,7 +159,7 @@ class BLOOMSharded(BLOOM):
                         f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                     )

-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)

                 if quantize:
                     if not HAS_BITS_AND_BYTES:
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 228529c..508b774 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -373,7 +373,7 @@ class LlamaMLP(nn.Module):
                 x,
                 approximate="tanh"
                 if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else None,
+                else "none",
             )
         )

diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 4ff1761..16fd409 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -376,7 +376,12 @@ class FlashMLP(nn.Module):
         self.act = (
             ACT2FN[act]
             if "gelu" not in act
-            else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
+            else lambda x: torch.nn.functional.gelu(
+                x,
+                approximate="tanh"
+                if act in ["gelu_fast", "gelu_pytorch_tanh"]
+                else "none",
+            )
         )

         if process_group is None:
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 29c4a5c..8679826 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -1,6 +1,8 @@
 import torch
 import torch.distributed

+import torch.nn.functional as F
+
 from torch import nn
 from transformers.activations import ACT2FN

@@ -65,6 +67,127 @@ class FastLinear(nn.Linear):
             return torch.matmul(input, self.weight)


+class TensorParallelColumnLinear(FastLinear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        bias=True,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_world_size = process_group.size()
+        assert out_features % self.tp_world_size == 0
+        out_features = out_features // self.tp_world_size
+
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+
+class TensorParallelRowLinear(FastLinear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        reduce=True,
+        bias=True,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_world_size = process_group.size()
+        self.reduce = reduce
+        assert in_features % self.tp_world_size == 0
+        in_features = in_features // self.tp_world_size
+
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        out = super(TensorParallelRowLinear, self).forward(input)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out
+
+
+class TensorParallelEmbedding(nn.Embedding):
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        process_group: torch.distributed.ProcessGroup,
+        reduce=True,
+        padding_idx=None,
+        max_norm=None,
+        norm_type=2.0,
+        scale_grad_by_freq=False,
+        sparse=False,
+        _weight=None,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_rank = process_group.rank()
+        self.tp_world_size = process_group.size()
+        self.reduce = reduce
+
+        self.original_num_embeddings = num_embeddings
+
+        assert num_embeddings % self.tp_world_size == 0
+        block_size = num_embeddings // self.tp_world_size
+        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
+        self.min_id = self.tp_rank * block_size
+        self.max_id = (self.tp_rank + 1) * block_size
+
+        # Additional entry that will map to zero
+        # Used for masking
+        self.null_idx = block_size
+
+        super().__init__(
+            block_size,
+            embedding_dim,
+            padding_idx=padding_idx,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            scale_grad_by_freq=scale_grad_by_freq,
+            sparse=sparse,
+            _weight=_weight,
+            device=device,
+            dtype=dtype,
+        )
+
+    def add_null_idx(self):
+        """Additional 0 entry used for masking"""
+        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
+        # translate for [0, self.max_id - self.min_id[
+        input = torch.where(
+            (self.min_id > input) | (input >= self.max_id),
+            self.null_idx,
+            input - self.min_id,
+        )
+        out = super().forward(input)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
+        return out
+
+
 class FlashMQAttention(torch.nn.Module):
     def __init__(
         self,
@@ -80,10 +203,16 @@ class FlashMQAttention(torch.nn.Module):
         self.softmax_scale = self.head_size ** (-0.5)

         if process_group is None:
-            self.attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
+            self.c_attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
             self.c_proj = FastLinear(hidden_size, hidden_size)
         else:
-            raise NotImplementedError
+            self.num_heads = self.num_heads // process_group.size()
+            self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
+            self.c_proj = TensorParallelRowLinear(
+                hidden_size,
+                hidden_size,
+                process_group=process_group,
+            )

     def forward(
         self,
@@ -94,10 +223,12 @@ class FlashMQAttention(torch.nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        qkv = self.attn(hidden_states)
+        qkv = self.c_attn(hidden_states)

         # Split query from key_value
-        query, key_value = qkv.split([self.hidden_size, 2 * self.head_size], dim=1)
+        query, key_value = qkv.split(
+            [self.head_size * self.num_heads, 2 * self.head_size], dim=1
+        )

         # Prepare query and key_value for indexing
         query = query.view(-1, self.num_heads, self.head_size)
@@ -171,7 +302,7 @@ class MLP(nn.Module):
                 x,
                 approximate="tanh"
                 if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else None,
+                else "none",
             )
         )
@@ -179,7 +310,16 @@
         if process_group is None:
             self.c_fc = FastLinear(hidden_size, intermediate_size)
             self.c_proj = FastLinear(intermediate_size, hidden_size)
         else:
-            raise NotImplementedError
+            self.c_fc = TensorParallelColumnLinear(
+                hidden_size,
+                intermediate_size,
+                process_group=process_group,
+            )
+            self.c_proj = TensorParallelRowLinear(
+                intermediate_size,
+                hidden_size,
+                process_group=process_group,
+            )

     def forward(self, hidden_states):
         hidden_states = self.c_fc(hidden_states)
@@ -246,11 +386,30 @@ class FlashSantacoderModel(nn.Module):
         super().__init__()
         self.config = config

+        self.process_group = process_group
+        self.tp_embeddings = False
         if process_group is not None:
-            raise NotImplementedError
+            self.tp_rank = process_group.rank()
+            self.tp_world_size = process_group.size()
+            if config.vocab_size % self.tp_world_size == 0:
+                self.tp_embeddings = True

-        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        if self.tp_embeddings:
+            self.wte = TensorParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                reduce=False,
+                process_group=process_group,
+            )
+            self.wpe = TensorParallelEmbedding(
+                config.max_position_embeddings,
+                config.hidden_size,
+                reduce=False,
+                process_group=process_group,
+            )
+        else:
+            self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
+            self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)

         self.h = nn.ModuleList(
             [
@@ -273,9 +432,12 @@
         self.num_heads = self.h[0].attn.num_heads

     def post_load_weights(self):
+        if self.tp_embeddings:
+            self.wte.add_null_idx()
+            self.wpe.add_null_idx()
         for layer in self.h:
             layer: Block
-            layer.attn.attn.transpose_weight()
+            layer.attn.c_attn.transpose_weight()
             layer.attn.c_proj.transpose_weight()
             layer.mlp.c_fc.transpose_weight()
             layer.mlp.c_proj.transpose_weight()
@@ -289,6 +451,8 @@
         past_key_values=None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
+        if self.tp_embeddings:
+            torch.distributed.all_reduce(hidden_states, group=self.process_group)

         # Prefill
         if past_key_values is None:
@@ -335,7 +499,14 @@ class FlashSantacoderForCausalLM(nn.Module):

         self.transformer = FlashSantacoderModel(config, process_group)

-        self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
+        if self.transformer.tp_embeddings:
+            self.lm_head = FastLinear(
+                config.hidden_size,
+                config.vocab_size // process_group.size(),
+                bias=False,
+            )
+        else:
+            self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)

     def post_load_weights(self):
         self.transformer.post_load_weights()
@@ -352,4 +523,18 @@ class FlashSantacoderForCausalLM(nn.Module):
         hidden_states, present = self.transformer(
             input_ids, position_ids, cu_seqlens, max_s, past_key_values
         )
-        return self.lm_head(hidden_states), present
+        logits = self.lm_head(hidden_states)
+
+        if self.transformer.tp_embeddings:
+            # Logits are sharded, so we need to gather them
+            world_logits = [
+                torch.empty_like(logits) for _ in range(self.transformer.tp_world_size)
+            ]
+            torch.distributed.all_gather(
+                world_logits, logits, group=self.transformer.process_group
+            )
+            world_logits = torch.cat(world_logits, dim=1)
+
+            return world_logits, present
+
+        return logits, present
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index b93d9f7..a8b3846 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -5,7 +5,7 @@
 from accelerate import init_empty_weights
 from opentelemetry import trace
 from safetensors import safe_open
 from transformers import AutoTokenizer, AutoConfig
-from typing import Optional, Tuple, List
+from typing import Optional, List

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_neox_modeling import (
@@ -63,13 +63,13 @@ class FlashNeoXSharded(FlashNeoX):
         self.load_weights(
             model,
             filenames,
-            quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
         model.post_load_weights()
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -80,16 +80,14 @@ class FlashNeoXSharded(FlashNeoX):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
-            with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
-            ) as f:
+            with safe_open(file, framework="pt", device=str(device)) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
                     module = model.get_submodule(module_name)
@@ -142,7 +140,7 @@ class FlashNeoXSharded(FlashNeoX):
                             f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                         )

-                    tensor = tensor.contiguous()
+                    tensor = tensor.contiguous().to(dtype)

                     if current_parameter_tensor is not None:
                         module._parameters[param_name] = tensor
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index e10d259..39381e9 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -3,15 +3,20 @@
 import torch.distributed

 from accelerate import init_empty_weights
 from opentelemetry import trace
+from safetensors import safe_open
 from pathlib import Path
-from transformers import AutoTokenizer, AutoConfig
+from transformers import AutoTokenizer, GPT2Config
 from typing import Optional, List

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
     FlashSantacoderForCausalLM,
+    TensorParallelRowLinear,
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
 )
 from text_generation_server.utils import (
+    initialize_torch_distributed,
     weight_files,
     download_weights,
     weight_hub_files,
@@ -36,10 +41,9 @@ class FlashSantacoder(FlashCausalLM):
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )

-        config = AutoConfig.from_pretrained(
+        config = GPT2Config.from_pretrained(
             model_id,
             revision=revision,
-            trust_remote_code=True,  # Needed as the config is not part of Transformers
         )

         # We do not use from_pretrained as we modified the model internal module layout
@@ -54,12 +58,9 @@ class FlashSantacoder(FlashCausalLM):
             model = FlashSantacoderForCausalLM(config)

         self.load_weights(
-            model,
-            filenames,
-            device,
-            dtype,
+            model, filenames, device, dtype, config.architectures[0].startswith("GPT2")
         )
-        self.model = model.eval().to(device).to(dtype)
+        self.model = model.eval()

         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer, device=device, decode_buffer=1
         )
@@ -71,6 +72,7 @@ class FlashSantacoder(FlashCausalLM):
         filenames: List[Path],
         device: torch.device,
         dtype: torch.dtype,
+        transpose: bool,
    ):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
@@ -81,9 +83,9 @@ class FlashSantacoder(FlashCausalLM):
                 # Fused qkv
                 if "q_attn.weight" in key or "kv_attn.weight" in key:
-                    final_key = layer_name + ".attn.weight"
+                    final_key = layer_name + ".c_attn.weight"
                 elif "q_attn.bias" in key or "kv_attn.bias" in key:
-                    final_key = layer_name + ".attn.bias"
+                    final_key = layer_name + ".c_attn.bias"
                 else:
                     final_key = key

@@ -97,18 +99,19 @@ class FlashSantacoder(FlashCausalLM):
                     current_parameter_tensor = None

                 if current_parameter_tensor is not None:
-                    if (
+                    if transpose and (
                         "c_fc.weight" in key
                         or "c_proj.weight" in key
                         or "q_attn.weight" in key
                         or "kv_attn.weight" in key
+                        or "c_attn.weight" in key
                     ):
                         # Transpose as we use nn.Linear instead of Conv1D
                         value = value.T

                     if current_parameter_tensor.device == torch.device("meta"):
                         # Init qkv
-                        if "attn.weight" in final_key:
+                        if "c_attn.weight" in final_key:
                             module._parameters[param_name] = value.new_empty(
                                 (
                                     model.transformer.head_size
                                     * (model.transformer.num_heads + 2),
                                     value.shape[1],
                                 )
                             )
-                        elif "attn.bias" in final_key:
+                        elif "c_attn.bias" in final_key:
                             module._parameters[param_name] = value.new_empty(
                                 (
                                     model.transformer.head_size
                                     * (model.transformer.num_heads + 2)
                                 )
                             )
@@ -156,3 +159,208 @@ class FlashSantacoder(FlashCausalLM):
         return self.tokenizer.decode(
             generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
         )
+
+
+class FlashSantacoderSharded(FlashSantacoder):
+    def __init__(
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+    ):
+        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
+        self.master = self.rank == 0
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{self.rank}")
+            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        else:
+            raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
+
+        if quantize:
+            raise NotImplementedError(
+                "FlashSantacoderSharded does not support quantization"
+            )
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id, revision=revision, padding_side="left", truncation_side="left"
+        )
+
+        config = GPT2Config.from_pretrained(
+            model_id,
+            revision=revision,
+        )
+
+        torch.distributed.barrier(group=self.process_group)
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+
+        with init_empty_weights():
+            model = FlashSantacoderForCausalLM(config, self.process_group)
+
+        torch.distributed.barrier(group=self.process_group)
+        self.load_weights(
+            model,
+            filenames,
+            device=device,
+            dtype=dtype,
+            rank=self.rank,
+            world_size=self.world_size,
+            transpose=config.architectures[0].startswith("GPT2"),
+        )
+        self.model = model.eval()
+        torch.distributed.barrier(group=self.process_group)
+        super(FlashCausalLM, self).__init__(
+            tokenizer=tokenizer,
+            device=device,
+        )
+
+    @staticmethod
+    def load_weights(
+        model,
+        filenames: List[str],
+        device: torch.device,
+        dtype: torch.dtype,
+        rank: int,
+        world_size: int,
+        transpose: bool,
+    ):
+        for file in filenames:
+            with safe_open(file, framework="pt", device=str(device)) as f:
+                for key in f.keys():
+                    slice_ = f.get_slice(key)
+
+                    layer_name = ".".join(key.split(".")[:4])
+
+                    # Fused qkv
+                    if "q_attn.weight" in key or "kv_attn.weight" in key:
+                        final_key = layer_name + ".c_attn.weight"
+                    elif "q_attn.bias" in key or "kv_attn.bias" in key:
+                        final_key = layer_name + ".c_attn.bias"
+                    else:
+                        final_key = key
+
+                    module_name, param_name = final_key.rsplit(".", 1)
+                    module = model.get_submodule(module_name)
+
+                    if isinstance(module, TensorParallelColumnLinear):
+                        dim = 1 if transpose and "weight" in param_name else 0
+                        size = slice_.get_shape()[dim]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = (
+                            slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                        )
+                    elif isinstance(module, TensorParallelRowLinear):
+                        if param_name == "weight":
+                            dim = 0 if transpose else 1
+                            size = slice_.get_shape()[dim]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = (
+                                slice_[start:stop]
+                                if dim == 0
+                                else slice_[:, start:stop]
+                            )
+                        else:
+                            tensor = slice_[:]
+                            # XXX: Hack for Rowlinear to add the bias only once.
+                            if rank != 0:
+                                tensor = torch.zeros_like(tensor)
+                    elif isinstance(module, TensorParallelEmbedding):
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    elif key == "lm_head.weight" and model.transformer.tp_embeddings:
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    else:
+                        try:
+                            tensor = slice_[:]
+                        except:
+                            tensor = f.get_tensor(key)
+
+                    tensor = tensor.contiguous().to(dtype)
+
+                    try:
+                        current_parameter_tensor = module._parameters[param_name]
+                    except KeyError:
+                        current_parameter_tensor = None
+
+                    if current_parameter_tensor is not None:
+                        if transpose and (
+                            "c_fc.weight" in key
+                            or "c_proj.weight" in key
+                            or "q_attn.weight" in key
+                            or "kv_attn.weight" in key
+                            or "c_attn.weight" in key
+                        ):
+                            # Transpose as we use nn.Linear instead of Conv1D
+                            tensor = tensor.T
+
+                        if current_parameter_tensor.device == torch.device("meta"):
+                            # Init qkv
+                            if "c_attn.weight" in final_key:
+                                module._parameters[param_name] = tensor.new_empty(
+                                    (
+                                        model.transformer.head_size
+                                        * (model.transformer.num_heads + 2),
+                                        tensor.shape[1],
+                                    )
+                                )
+                            elif "c_attn.bias" in final_key:
+                                module._parameters[param_name] = tensor.new_empty(
+                                    (
+                                        model.transformer.head_size
+                                        * (model.transformer.num_heads + 2)
+                                    )
+                                )
+
+                        # Copy to correct slice
+                        if "q_attn" in key:
+                            size = tensor.shape[0]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = tensor[start:stop]
+                            module._parameters[param_name][: tensor.shape[0]] = tensor
+                        elif "kv_attn.weight" in key:
+                            module._parameters[param_name][
+                                model.transformer.head_size
+                                * model.transformer.num_heads :
+                            ] = tensor
+                        elif "kv_attn.bias" in key:
+                            module._parameters[param_name][
+                                model.transformer.head_size
+                                * model.transformer.num_heads :
+                            ] = tensor
+                        elif "c_attn" in key:
+                            # Slice q_tensor by shard
+                            q_tensor = tensor[: -2 * model.transformer.head_size]
+                            block_size = q_tensor.shape[0] // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            q_tensor = q_tensor[start:stop]
+
+                            module._parameters[param_name][
+                                : q_tensor.shape[0]
+                            ] = q_tensor
+
+                            # Kv tensor is copied for every shard
+                            kv_tensor = tensor[-2 * model.transformer.head_size :]
+                            module._parameters[param_name][
+                                q_tensor.shape[0] :
+                            ] = kv_tensor
+                        else:
+                            if current_parameter_tensor.shape != tensor.shape:
+                                raise ValueError(
+                                    f"Name {key} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
+                                )
+
+                            module._parameters[param_name] = tensor
+                    else:
+                        module._buffers[param_name] = tensor
+
+        torch.cuda.empty_cache()
+        model.post_load_weights()
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index 396cc4f..dc78aa8 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -219,10 +219,11 @@ class GalacticaSharded(Galactica):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -235,6 +236,7 @@ class GalacticaSharded(Galactica):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -285,7 +287,7 @@ class GalacticaSharded(Galactica):
                         f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                     )

-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)

                 if quantize:
                     if not HAS_BITS_AND_BYTES:
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index fb109ed..489615e 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -64,10 +64,11 @@ class GPTNeoxSharded(CausalLM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -80,6 +81,7 @@ class GPTNeoxSharded(CausalLM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -140,7 +142,7 @@ class GPTNeoxSharded(CausalLM):
                         f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                     )

-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)

                 if quantize:
                     if not HAS_BITS_AND_BYTES:
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 85f0ac8..8e5527c 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -80,10 +80,11 @@ class OPTSharded(OPT):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -96,6 +97,7 @@ class OPTSharded(OPT):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -146,7 +148,7 @@ class OPTSharded(OPT):
                         f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                    )

-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)

                 if quantize:
                     if not HAS_BITS_AND_BYTES:
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 5266eb8..b9f7701 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -64,10 +64,11 @@ class T5Sharded(Seq2SeqLM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(Seq2SeqLM, self).__init__(
             tokenizer=tokenizer,
@@ -80,6 +81,7 @@ class T5Sharded(Seq2SeqLM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -146,7 +148,7 @@ class T5Sharded(Seq2SeqLM):
                         f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                     )

-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)

                 if quantize:
                     if not HAS_BITS_AND_BYTES:
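
Note for reviewers (not part of the patch): the tensor-parallel split introduced above -- column-parallel c_fc / c_attn, row-parallel c_proj, and an all_reduce in TensorParallelRowLinear.forward -- can be sanity-checked on a single process. The sketch below is illustrative only: the sizes and variable names are made up, the "ranks" are simulated with tensor chunks, and a plain sum stands in for the torch.distributed.all_reduce used by the real code.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 2
hidden, intermediate = 8, 16

x = torch.randn(4, hidden)
w_fc = torch.randn(intermediate, hidden)   # nn.Linear layout: (out_features, in_features)
w_proj = torch.randn(hidden, intermediate)

# Column parallelism: c_fc is split along its output dimension.
# Row parallelism: c_proj is split along its input dimension.
fc_shards = w_fc.chunk(world_size, dim=0)
proj_shards = w_proj.chunk(world_size, dim=1)

# Each simulated "rank" computes a partial MLP output; summing the partials
# plays the role of the all_reduce in TensorParallelRowLinear.forward.
partials = [
    F.gelu(x @ fc.T, approximate="tanh") @ proj.T
    for fc, proj in zip(fc_shards, proj_shards)
]
sharded_out = torch.stack(partials).sum(dim=0)

full_out = F.gelu(x @ w_fc.T, approximate="tanh") @ w_proj.T
print(torch.allclose(sharded_out, full_out, atol=1e-4))  # expected: True

The same reasoning explains why the sharded lm_head uses an all_gather instead: its output dimension (the vocabulary) is the one that is split, so each rank holds a different slice of the logits rather than a partial sum.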