import torch
import torch.distributed

from typing import List, Optional, Tuple

from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoConfig,
)

from text_generation_server.models import Seq2SeqLM
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)
from transformers.models.t5.parallel_layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
)

HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params
except ImportError:
    HAS_BITS_AND_BYTES = False


class T5Sharded(Seq2SeqLM):
    """T5 sharded across tensor-parallel ranks: weights are sliced at load
    time and the sharded logits are all-gathered in `forward`."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, tp_parallel=True
        )
        tokenizer.bos_token_id = config.decoder_start_token_id

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        with init_empty_weights():
            model = AutoModelForSeq2SeqLM.from_config(config)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=rank,
            world_size=world_size,
        )
        torch.distributed.barrier(group=self.process_group)
        super(Seq2SeqLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: Optional[str],
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
    ):
        """Load safetensors checkpoints, keeping only this rank's slice of the
        tensor-parallel layers and optionally quantizing linear weights with
        bitsandbytes."""
        parameters = dict(model.named_parameters())
        for file in filenames:
            with safe_open(
                file, framework="pt", device=str(device) if quantize is None else "cpu"
            ) as f:
                for name in f.keys():
                    module_name, param_name = name.rsplit(".", 1)
                    module = model.get_submodule(module_name)

                    current_parameter_tensor = parameters.get(name, None)

                    slice_ = f.get_slice(name)

                    if isinstance(module, TensorParallelColumnLinear):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            size = slice_.get_shape()[1]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[:, start:stop]
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif name == "lm_head.weight":
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif "relative_attention_bias.weight" in name:
                        size = slice_.get_shape()[1]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[:, start:stop]
                    else:
                        try:
                            tensor = slice_[:]
                        except Exception:
                            tensor = f.get_tensor(name)

                    if (
                        current_parameter_tensor is not None
                        and current_parameter_tensor.shape != tensor.shape
                    ):
                        raise ValueError(
                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous()

                    # See: https://github.com/huggingface/transformers/blob/1fe1e3caa44617047f149bcc0c0b566343b714a7/src/transformers/models/t5/modeling_t5.py#LL316C15-L316C71
                    if module_name.endswith("wo"):
                        tensor = tensor.to(torch.float32)
                    else:
                        tensor = tensor.to(dtype)

                    if quantize == "bitsandbytes" and not module_name.endswith("wo"):
                        if not HAS_BITS_AND_BYTES:
                            raise ImportError(
                                "bitsandbytes is not available on your machine either because it is not installed "
                                "or you don't have a GPU.\n"
                                "You can install it with `pip install bitsandbytes`."
                            )

                        if (
                            type(module)
                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
                            and param_name == "weight"
                        ):
                            tensor = Int8Params(
                                tensor,
                                has_fp16_weights=False,
                                requires_grad=False,
                            ).to(device)
                            state = bnb.MatmulLtState()
                            state.threshold = 6.0
                            state.has_fp16_weights = False
                            state.memory_efficient_backward = False
                            state.use_pool = True
                            state.CB = tensor.CB
                            state.SCB = tensor.SCB
                            tensor.CB = None
                            tensor.SCB = None

                            def replace_linear(state):
                                def linear(input, weight, bias):
                                    out = bnb.matmul(
                                        input,
                                        weight,
                                        state=state,
                                        threshold=state.threshold,
                                        bias=bias,
                                    )

                                    if state.CB is not None:
                                        # we converted 8-bit row major to turing/ampere format
                                        # in the first inference pass
                                        # we no longer need the row-major weight
                                        del state.CB
                                        weight.data = state.CxB

                                    return out

                                return linear

                            module.linear = replace_linear(state)
                        else:
                            # Non-linear parameters (embeddings, layer norms, ...)
                            # are not quantized but still need to move to the
                            # target device.
                            tensor = tensor.to(device)
                    elif quantize == "gptq" and not module_name.endswith("wo"):
                        raise NotImplementedError("`gptq` is not implemented for now")
                    elif quantize is None or module_name.endswith("wo"):
                        # `wo` weights are kept in float32 and never quantized,
                        # so they are moved to the device here as well.
                        tensor = tensor.to(device)
                    else:
                        raise ValueError(f"Unexpected quantize `{quantize}`")

                    if current_parameter_tensor is not None:
                        module._parameters[param_name] = tensor
                    else:
                        module._buffers[param_name] = tensor

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional,
        encoder_last_hidden_state: Optional,
        past_key_values: Optional = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # Logits are sharded, so we need to gather them
        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
        logits = torch.cat(logits, dim=2)

        return (
            logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )
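
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the server code): a single-process
# simulation of the slicing arithmetic used in `load_weights` and the logit
# gather in `forward`, runnable without torch.distributed. All names below
# (`fake_vocab`, `fake_hidden`, ...) are invented for the example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    fake_vocab, fake_hidden, world_size = 32, 8, 4

    # Row-block slicing, as in the `lm_head.weight` branch above: each rank
    # keeps a contiguous block of `fake_vocab // world_size` rows.
    weight = torch.randn(fake_vocab, fake_hidden)
    block_size = fake_vocab // world_size
    shards = [
        weight[rank * block_size : (rank + 1) * block_size]
        for rank in range(world_size)
    ]
    assert torch.equal(torch.cat(shards, dim=0), weight)

    # Each rank therefore produces logits over only its vocabulary block;
    # `forward` all-gathers these and concatenates on dim=2 (batch, seq_len,
    # vocab). Here a plain list stands in for torch.distributed.all_gather.
    hidden = torch.randn(2, 3, fake_hidden)  # (batch, seq_len, hidden)
    sharded_logits = [hidden @ shard.T for shard in shards]
    full_logits = torch.cat(sharded_logits, dim=2)
    assert full_logits.shape == (2, 3, fake_vocab)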