From f26dfd0dc1f9e930209bee3c585236278baeae05 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Tue, 11 Apr 2023 19:16:41 +0200 Subject: [PATCH] feat(server): support OPT models (#55) OPT models do not all have a `tokenizer.json` file on the hub at the moment. Can't merge for now. --- README.md | 7 +- launcher/src/main.rs | 20 +- .../text_generation_server/models/__init__.py | 26 +- server/text_generation_server/models/bloom.py | 19 +- .../models/galactica.py | 24 +- .../text_generation_server/models/gpt_neox.py | 2 +- server/text_generation_server/models/opt.py | 224 ++++++++++++++++++ server/text_generation_server/models/t5.py | 2 +- server/text_generation_server/utils/hub.py | 5 - 9 files changed, 270 insertions(+), 59 deletions(-) create mode 100644 server/text_generation_server/models/opt.py diff --git a/README.md b/README.md index bc77fd4..60c1a6b 100644 --- a/README.md +++ b/README.md @@ -54,11 +54,12 @@ to power LLMs api-inference widgets. ## Optimized architectures - [BLOOM](https://huggingface.co/bigscience/bloom) -- [Galactica](https://huggingface.co/facebook/galactica-120b) -- [SantaCoder](https://huggingface.co/bigcode/santacoder) -- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b) - [FLAN-T5](https://huggingface.co/google/flan-t5-xxl) +- [Galactica](https://huggingface.co/facebook/galactica-120b) +- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b) - [Llama](https://github.com/facebookresearch/llama) +- [OPT](https://huggingface.co/facebook/opt-66b) +- [SantaCoder](https://huggingface.co/bigcode/santacoder) Other architectures are supported on a best effort basis using: diff --git a/launcher/src/main.rs b/launcher/src/main.rs index a598a8b..2b152be 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -349,8 +349,8 @@ fn main() -> ExitCode { Err(TryRecvError::Empty) => { sleep(Duration::from_millis(100)); } - Ok(ShardStatus::Failed((rank, err))) => { - tracing::error!("Shard {} failed to start:\n{}", rank, err); + Ok(ShardStatus::Failed(rank)) => { + tracing::error!("Shard {} failed to start.", rank); shutdown_shards(shutdown, &shutdown_receiver); return ExitCode::FAILURE; } @@ -457,8 +457,8 @@ fn main() -> ExitCode { let mut exit_code = ExitCode::SUCCESS; while running.load(Ordering::SeqCst) { - if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() { - tracing::error!("Shard {rank} failed:\n{err}"); + if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() { + tracing::error!("Shard {rank} failed."); exit_code = ExitCode::FAILURE; break; }; @@ -488,7 +488,7 @@ fn main() -> ExitCode { #[derive(Debug)] enum ShardStatus { Ready, - Failed((usize, String)), + Failed(usize), } #[allow(clippy::too_many_arguments)] @@ -627,9 +627,7 @@ fn shard_manager( tracing::error!("Please install it with `make install-server`") } } - status_sender - .send(ShardStatus::Failed((rank, err.to_string()))) - .unwrap(); + status_sender.send(ShardStatus::Failed(rank)).unwrap(); return; } }; @@ -658,11 +656,7 @@ fn shard_manager( loop { // Process exited if p.poll().is_some() { - let mut err = String::new(); - p.stderr.take().unwrap().read_to_string(&mut err).unwrap(); - status_sender - .send(ShardStatus::Failed((rank, err))) - .unwrap(); + status_sender.send(ShardStatus::Failed(rank)).unwrap(); return; } diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 1e06b6d..c04ae11 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1,4 +1,3 @@ -import os import torch from loguru import logger @@ -11,6 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.bloom import BLOOM, BLOOMSharded from text_generation_server.models.seq2seq_lm import Seq2SeqLM +from text_generation_server.models.opt import OPT, OPTSharded from text_generation_server.models.galactica import Galactica, GalacticaSharded from text_generation_server.models.santacoder import SantaCoder from text_generation_server.models.gpt_neox import GPTNeoxSharded @@ -36,7 +36,11 @@ __all__ = [ "GalacticaSharded", "GPTNeoxSharded", "Seq2SeqLM", + "Galactica", + "GalacticaSharded", "SantaCoder", + "OPT", + "OPTSharded", "T5Sharded", "get_model", ] @@ -48,9 +52,11 @@ if FLASH_ATTENTION: __all__.append(FlashLlama) __all__.append(FlashLlamaSharded) -FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention CUDA kernels to be installed.\n" \ - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " \ - "or install flash attention with `cd server && make install install-flash-attention`" +FLASH_ATT_ERROR_MESSAGE = ( + "{} requires Flash Attention CUDA kernels to be installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" +) # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. @@ -64,7 +70,7 @@ torch.set_grad_enabled(False) def get_model( - model_id: str, revision: Optional[str], sharded: bool, quantize: bool + model_id: str, revision: Optional[str], sharded: bool, quantize: bool ) -> Model: if "facebook/galactica" in model_id: if sharded: @@ -100,13 +106,17 @@ def get_model( if sharded: if FLASH_ATTENTION: return FlashLlamaSharded(model_id, revision, quantize=quantize) - raise NotImplementedError( - FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama") - ) + raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama")) else: llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM return llama_cls(model_id, revision, quantize=quantize) + if config.model_type == "opt": + if sharded: + return OPTSharded(model_id, revision, quantize=quantize) + else: + return OPT(model_id, revision, quantize=quantize) + if model_type == "t5": if sharded: return T5Sharded(model_id, revision, quantize=quantize) diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index ce3895c..1a96102 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -62,7 +62,7 @@ class BLOOMSharded(BLOOM): self.master = self.rank == 0 if torch.cuda.is_available(): device = torch.device(f"cuda:{self.rank}") - dtype = torch.bfloat16 + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 else: device = torch.device("cpu") dtype = torch.float32 @@ -122,18 +122,11 @@ class BLOOMSharded(BLOOM): slice_ = f.get_slice(name) if isinstance(module, TensorParallelColumnLinear): - if param_name == "weight": - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] elif isinstance(module, TensorParallelRowLinear): if param_name == "weight": size = slice_.get_shape()[1] diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index f7fbb2a..396cc4f 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -19,8 +19,9 @@ from transformers.models.opt.parallel_layers import ( ) from text_generation_server.models import CausalLM -from text_generation_server.pb import generate_pb2 from text_generation_server.models.causal_lm import CausalLMBatch +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.opt import OPT from text_generation_server.utils import ( NextTokenChooser, StoppingCriteria, @@ -158,7 +159,7 @@ class GalacticaCausalLMBatch(CausalLMBatch): ) -class Galactica(CausalLM): +class Galactica(OPT): @property def batch_type(self) -> Type[CausalLMBatch]: return GalacticaCausalLMBatch @@ -192,7 +193,7 @@ class GalacticaSharded(Galactica): self.master = self.rank == 0 if torch.cuda.is_available(): device = torch.device(f"cuda:{self.rank}") - dtype = torch.bfloat16 + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 else: device = torch.device("cpu") dtype = torch.float32 @@ -253,18 +254,11 @@ class GalacticaSharded(Galactica): slice_ = f.get_slice(name) if isinstance(module, TensorParallelColumnLinear): - if param_name == "weight": - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] - else: - size = slice_.get_shape()[0] - block_size = size // world_size - start = rank * block_size - stop = (rank + 1) * block_size - tensor = slice_[start:stop] + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] elif isinstance(module, TensorParallelRowLinear): if param_name == "weight": size = slice_.get_shape()[1] diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index b81976d..fb109ed 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -38,7 +38,7 @@ class GPTNeoxSharded(CausalLM): self.master = self.rank == 0 if torch.cuda.is_available(): device = torch.device(f"cuda:{self.rank}") - dtype = torch.bfloat16 + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py new file mode 100644 index 0000000..85f0ac8 --- /dev/null +++ b/server/text_generation_server/models/opt.py @@ -0,0 +1,224 @@ +import torch +import torch.distributed + +from typing import List, Optional, Tuple + +from accelerate import init_empty_weights +from safetensors import safe_open +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + AutoConfig, +) +from transformers.models.opt.parallel_layers import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, +) + +from text_generation_server.models import CausalLM +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, +) + +HAS_BITS_AND_BYTES = True +try: + import bitsandbytes as bnb + from bitsandbytes.nn import Int8Params +except Exception as e: + HAS_BITS_AND_BYTES = False + + +class OPT(CausalLM): + def forward( + self, input_ids, attention_mask, position_ids, past_key_values: Optional = None + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + """Overwrite forward to ignore position_ids""" + + # Model Forward + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True, + ) + return outputs.logits, outputs.past_key_values + + +class OPTSharded(OPT): + def __init__( + self, model_id: str, revision: Optional[str] = None, quantize: bool = False + ): + self.process_group, self.rank, self.world_size = initialize_torch_distributed() + self.master = self.rank == 0 + if torch.cuda.is_available(): + device = torch.device(f"cuda:{self.rank}") + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 + else: + device = torch.device("cpu") + dtype = torch.float32 + + tokenizer = AutoTokenizer.from_pretrained( + model_id, revision=revision, padding_side="left", truncation_side="left" + ) + + config = AutoConfig.from_pretrained( + model_id, revision=revision, tp_parallel=True + ) + tokenizer.pad_token_id = config.pad_token_id + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + + torch.distributed.barrier(group=self.process_group) + self.load_weights( + model, + filenames, + quantize=quantize, + device=device, + rank=self.rank, + world_size=self.world_size, + ) + self.model = model.eval().to(dtype) + torch.distributed.barrier(group=self.process_group) + super(CausalLM, self).__init__( + tokenizer=tokenizer, + device=device, + ) + + @staticmethod + def load_weights( + model, + filenames: List[str], + quantize: bool, + device: torch.device, + rank: int, + world_size: int, + ): + parameters = dict(model.named_parameters()) + for file in filenames: + with safe_open( + file, framework="pt", device=str(device) if not quantize else "cpu" + ) as f: + for name in f.keys(): + if name == "lm_head.weight": + continue + + module_name, param_name = name.rsplit(".", 1) + module = model.get_submodule(module_name) + current_tensor = parameters[name] + + slice_ = f.get_slice(name) + + if isinstance(module, TensorParallelColumnLinear): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif isinstance(module, TensorParallelRowLinear): + if param_name == "weight": + size = slice_.get_shape()[1] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[:, start:stop] + else: + tensor = slice_[:] + # XXX: Hack for Rowlinear to add the bias only once. + if rank != 0: + tensor = torch.zeros_like(tensor) + elif isinstance(module, TensorParallelEmbedding): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + else: + tensor = slice_[:] + + if current_tensor.shape != tensor.shape: + raise ValueError( + f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}" + ) + + tensor = tensor.contiguous() + + if quantize: + if not HAS_BITS_AND_BYTES: + raise ImportError( + "bitsandbytes is not available on your machine either because it is not installed " + "or you don't have a GPU.\n" + "You can install it with `pip install bitsandbytes`." + ) + + if ( + type(module) + in [TensorParallelRowLinear, TensorParallelColumnLinear] + and param_name == "weight" + ): + tensor = Int8Params( + tensor, + has_fp16_weights=False, + requires_grad=False, + ).to(device) + state = bnb.MatmulLtState() + state.threshold = 6.0 + state.has_fp16_weights = False + state.memory_efficient_backward = False + state.use_pool = True + state.CB = tensor.CB + state.SCB = tensor.SCB + tensor.CB = None + tensor.SCB = None + + def replace_linear(state): + def linear(input, weight, bias): + out = bnb.matmul( + input, + weight, + state=state, + threshold=state.threshold, + bias=bias, + ) + + if state.CB is not None: + # we converted 8-bit row major to turing/ampere format + # in the first inference pass + # we no longer need the row-major weight + del state.CB + weight.data = state.CxB + + return out + + return linear + + module.linear = replace_linear(state) + + else: + tensor = tensor.to(device) + + module._parameters[param_name] = tensor + if name == "model.decoder.embed_tokens.weight": + model.lm_head._parameters["weight"] = tensor + + def forward( + self, input_ids, attention_mask, position_ids, past_key_values: Optional = None + ): + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True, + ) + + # Logits are sharded, so we need to gather them + logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)] + torch.distributed.all_gather(logits, outputs.logits, group=self.process_group) + logits = torch.cat(logits, dim=2) + + return logits, outputs.past_key_values diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 300b376..5266eb8 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -38,7 +38,7 @@ class T5Sharded(Seq2SeqLM): self.master = self.rank == 0 if torch.cuda.is_available(): device = torch.device(f"cuda:{self.rank}") - dtype = torch.bfloat16 + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index d338fb2..4feec8a 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -50,7 +50,6 @@ def try_to_load_from_cache( refs_dir = repo_cache / "refs" snapshots_dir = repo_cache / "snapshots" - no_exist_dir = repo_cache / ".no_exist" # Resolve refs (for instance to convert main to the associated commit sha) if refs_dir.is_dir(): @@ -59,10 +58,6 @@ def try_to_load_from_cache( with revision_file.open() as f: revision = f.read() - # Check if file is cached as "no_exist" - if (no_exist_dir / revision / filename).is_file(): - return None - # Check if revision folder exists if not snapshots_dir.exists(): return None