diff --git a/README.md b/README.md
index 5f943de3..54851c2a 100644
--- a/README.md
+++ b/README.md
@@ -53,7 +53,8 @@ to power LLMs api-inference widgets.
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
-- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision pr/13`
+- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b)
+- [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl): use `--revision pr/26`
 
 Other models are supported on a best effort basis using:
 
diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py
index 291705a6..7445b427 100644
--- a/server/text_generation/models/__init__.py
+++ b/server/text_generation/models/__init__.py
@@ -10,14 +10,20 @@ from text_generation.models.seq2seq_lm import Seq2SeqLM
 from text_generation.models.galactica import Galactica, GalacticaSharded
 from text_generation.models.santacoder import SantaCoder
 from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded
+from text_generation.models.t5 import T5Sharded
 
 __all__ = [
     "Model",
     "BLOOM",
     "BLOOMSharded",
     "CausalLM",
+    "Galactica",
+    "GalacticaSharded",
+    "GPTNeox",
+    "GPTNeoxSharded",
     "Seq2SeqLM",
     "SantaCoder",
+    "T5Sharded",
     "get_model",
 ]
 
@@ -47,6 +53,11 @@ def get_model(
             return GPTNeoxSharded(model_id, revision, quantize=quantize)
         else:
             return GPTNeox(model_id, revision, quantize=quantize)
+    elif config.model_type == "t5":
+        if sharded:
+            return T5Sharded(model_id, revision, quantize=quantize)
+        else:
+            return Seq2SeqLM(model_id, revision, quantize=quantize)
     elif model_id.startswith("facebook/galactica"):
         if sharded:
             return GalacticaSharded(model_id, revision, quantize=quantize)
diff --git a/server/text_generation/models/t5.py b/server/text_generation/models/t5.py
new file mode 100644
index 00000000..d7241c81
--- /dev/null
+++ b/server/text_generation/models/t5.py
@@ -0,0 +1,258 @@
+import torch
+import torch.distributed
+
+from typing import List, Optional, Tuple
+
+from accelerate import init_empty_weights
+from safetensors import safe_open
+from transformers import (
+    AutoTokenizer,
+    AutoModelForSeq2SeqLM,
+    AutoConfig,
+)
+from transformers.models.t5.parallel_layers import (
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+    TensorParallelRowLinear,
+)
+
+from text_generation.models import Seq2SeqLM
+from text_generation.utils import (
+    initialize_torch_distributed,
+    weight_files,
+    download_weights,
+)
+
+HAS_BITS_AND_BYTES = True
+try:
+    import bitsandbytes as bnb
+    from bitsandbytes.nn import Int8Params
+except Exception:
+    HAS_BITS_AND_BYTES = False
+
+
+class T5Sharded(Seq2SeqLM):
+    def __init__(
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+    ):
+        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
+        self.master = self.rank == 0
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{self.rank}")
+            dtype = torch.bfloat16
+        else:
+            device = torch.device("cpu")
+            dtype = torch.float32
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id, revision=revision, padding_side="left"
+        )
+
+        config = AutoConfig.from_pretrained(
+            model_id, revision=revision, tp_parallel=True
+        )
+        tokenizer.bos_token_id = config.decoder_start_token_id
+
+        # Only the master rank downloads the weights
+        if self.master:
+            download_weights(model_id, revision=revision, extension=".safetensors")
+
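+        # All ranks block here until the master's download above has finished,
+        # so the weight files are guaranteed to be on disk before any rank
+        # tries to open them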
+        torch.distributed.barrier(group=self.process_group)
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+        if not filenames:
+            raise ValueError("No safetensors weights found")
+
+        with init_empty_weights():
+            model = AutoModelForSeq2SeqLM.from_config(config)
+
+        torch.distributed.barrier(group=self.process_group)
+        self.load_weights(
+            model,
+            filenames,
+            quantize=quantize,
+            device=device,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+        self.model = model.eval().to(dtype)
+        torch.distributed.barrier(group=self.process_group)
+        super(Seq2SeqLM, self).__init__(
+            tokenizer=tokenizer,
+            device=device,
+        )
+
+    @staticmethod
+    def load_weights(
+        model,
+        filenames: List[str],
+        quantize: bool,
+        device: torch.device,
+        rank: int,
+        world_size: int,
+    ):
+        parameters = dict(model.named_parameters())
+        for file in filenames:
+            with safe_open(
+                file, framework="pt", device=str(device) if not quantize else "cpu"
+            ) as f:
+                for name in f.keys():
+                    module_name, param_name = name.rsplit(".", 1)
+                    module = model.get_submodule(module_name)
+
+                    current_parameter_tensor = parameters.get(name, None)
+
+                    slice_ = f.get_slice(name)
+
+                    if isinstance(module, TensorParallelColumnLinear):
+                        # Column-parallel: shard the weight along the output dim
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    elif isinstance(module, TensorParallelRowLinear):
+                        if param_name == "weight":
+                            # Row-parallel: shard the weight along the input dim
+                            size = slice_.get_shape()[1]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = slice_[:, start:stop]
+                        else:
+                            tensor = slice_[:]
+                            # XXX: Hack for Rowlinear to add the bias only once.
+                            if rank != 0:
+                                tensor = torch.zeros_like(tensor)
+                    elif isinstance(module, TensorParallelEmbedding):
+                        # Embeddings are sharded over the vocabulary dim
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    elif name == "lm_head.weight":
+                        # lm_head is vocab-sharded; logits are re-gathered in forward()
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    elif "relative_attention_bias.weight" in name:
+                        # Relative attention biases are sharded over the heads dim
+                        size = slice_.get_shape()[1]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[:, start:stop]
+                    else:
+                        try:
+                            tensor = slice_[:]
+                        except Exception:
+                            tensor = f.get_tensor(name)
+
+                    if (
+                        current_parameter_tensor is not None
+                        and current_parameter_tensor.shape != tensor.shape
+                    ):
+                        raise ValueError(
+                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
+                        )
+
+                    tensor = tensor.contiguous()
+
+                    if quantize:
+                        if not HAS_BITS_AND_BYTES:
+                            raise ImportError(
+                                "bitsandbytes is not available on your machine either because it is not installed "
+                                "or you don't have a GPU.\n"
+                                "You can install it with `pip install bitsandbytes`."
+                            )
+
+                        if (
+                            type(module)
+                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
+                            and param_name == "weight"
+                        ):
+                            tensor = Int8Params(
+                                tensor,
+                                has_fp16_weights=False,
+                                requires_grad=False,
+                            ).to(device)
+                            state = bnb.MatmulLtState()
+                            state.threshold = 6.0
+                            state.has_fp16_weights = False
+                            state.memory_efficient_backward = False
+                            state.use_pool = True
+                            state.CB = tensor.CB
+                            state.SCB = tensor.SCB
+                            tensor.CB = None
+                            tensor.SCB = None
+
+                            def replace_linear(state):
+                                def linear(input, weight, bias):
+                                    out = bnb.matmul(
+                                        input,
+                                        weight,
+                                        state=state,
+                                        threshold=state.threshold,
+                                        bias=bias,
+                                    )
+
+                                    if state.CB is not None:
+                                        # we converted 8-bit row major to turing/ampere format
+                                        # in the first inference pass
+                                        # we no longer need the row-major weight
+                                        del state.CB
+                                        weight.data = state.CxB
+
+                                    return out
+
+                                return linear
+
+                            module.linear = replace_linear(state)
+
+                        else:
+                            tensor = tensor.to(device)
+
+                    if current_parameter_tensor is not None:
+                        module._parameters[param_name] = tensor
+                    else:
+                        module._buffers[param_name] = tensor
+
+    def forward(
+        self,
+        input_ids,
+        attention_mask,
+        decoder_input_ids,
+        decoder_attention_mask: Optional[torch.Tensor],
+        encoder_last_hidden_state: Optional[torch.Tensor],
+        past_key_values: Optional[List[Tuple]] = None,
+    ) -> Tuple[
+        torch.Tensor,
+        torch.Tensor,
+        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
+    ]:
+        # Model Forward
+        # Once a KV cache exists, only the newest decoder input id is needed
+        if past_key_values is not None:
+            decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)
+
+        # Wrap `encoder_last_hidden_state` because, for some reason, Transformers
+        # does an `encoder_last_hidden_state[0]` internally...
+        if encoder_last_hidden_state is not None:
+            encoder_last_hidden_state = [encoder_last_hidden_state]
+
+        outputs = self.model.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            encoder_outputs=encoder_last_hidden_state,
+            past_key_values=past_key_values,
+            use_cache=True,
+        )
+
+        # Logits are sharded, so we need to gather them
+        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
+        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
+        logits = torch.cat(logits, dim=2)
+
+        return (
+            logits,
+            outputs.encoder_last_hidden_state,
+            outputs.past_key_values,
+        )
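
Because `lm_head.weight` is vocab-sharded in `load_weights`, each rank only produces logits for its slice of the vocabulary, and the `all_gather` plus `torch.cat(..., dim=2)` at the end of `forward` reassembles the full distribution. Below is a minimal, self-contained sketch of that reassembly pattern, not part of the PR; the two-process `gloo` CPU group, port, and tensor sizes are invented for illustration:

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int):
    # Hypothetical two-process CPU group standing in for the sharded GPU workers
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29511", rank=rank, world_size=world_size
    )
    batch, seq_len, vocab_shard = 2, 3, 8  # made-up sizes

    # Each rank holds logits only for its slice of the vocabulary
    local_logits = torch.full((batch, seq_len, vocab_shard), float(rank))

    # Same pattern as T5Sharded.forward: gather every shard, then concat on dim 2
    shards = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(shards, local_logits)
    full_logits = torch.cat(shards, dim=2)

    assert full_logits.shape == (batch, seq_len, vocab_shard * world_size)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
```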