From 05e9a796cc553b44608456bc3409d52515e56093 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 24 Mar 2023 14:02:14 +0100 Subject: [PATCH] feat(server): flash neoX (#133) --- .github/workflows/build.yaml | 4 + .github/workflows/tests.yaml | 4 + Dockerfile | 9 +- server/Makefile | 17 +- .../text_generation_server/models/__init__.py | 20 +- .../models/causal_lm.py | 1 - .../models/flash_neox.py | 601 +++++++++++++++++ .../models/flash_neox_modeling.py | 637 ++++++++++++++++++ server/text_generation_server/utils/tokens.py | 2 +- .../text_generation_server/utils/watermark.py | 37 +- 10 files changed, 1307 insertions(+), 25 deletions(-) create mode 100644 server/text_generation_server/models/flash_neox.py create mode 100644 server/text_generation_server/models/flash_neox_modeling.py diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 220b2fa3..56015177 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -20,6 +20,10 @@ on: branches: - 'main' +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + jobs: build-and-push-image: runs-on: ubuntu-latest diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 1a45ad04..f96c53fb 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -11,6 +11,10 @@ on: - "Cargo.lock" - "rust-toolchain.toml" +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + jobs: run_tests: runs-on: ubuntu-20.04 diff --git a/Dockerfile b/Dockerfile index 5fbf8985..592f1f72 100644 --- a/Dockerfile +++ b/Dockerfile @@ -43,7 +43,7 @@ ENV LANG=C.UTF-8 \ CONDA_DEFAULT_ENV=text-generation \ PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin -RUN apt-get update && apt-get install -y unzip curl libssl-dev && rm -rf /var/lib/apt/lists/* +RUN apt-get update && apt-get install -y git curl libssl-dev && rm -rf /var/lib/apt/lists/* RUN cd ~ && \ curl -L -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ @@ -53,10 +53,13 @@ RUN cd ~ && \ WORKDIR /usr/src +# Install torch +RUN pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir + COPY server/Makefile server/Makefile -# Install specific version of torch -RUN cd server && make install-torch +# Install specific version of flash attention +RUN cd server && make install-flash-attention # Install specific version of transformers RUN cd server && BUILD_EXTENSIONS="True" make install-transformers diff --git a/server/Makefile b/server/Makefile index e8b0364e..69ef9bc5 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,4 +1,5 @@ transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef +flash_att_commit := 4d87e4d875077ad9efd25030efa4ab0ba92c19e1 gen-server: # Compile protos @@ -12,13 +13,19 @@ install-transformers: # Install specific version of transformers with custom cuda kernels pip uninstall transformers -y || true rm -rf transformers || true - rm -rf transformers-$(transformers_commit) || true - curl -L -O https://github.com/OlivierDehaene/transformers/archive/$(transformers_commit).zip - unzip $(transformers_commit).zip - rm $(transformers_commit).zip - mv transformers-$(transformers_commit) transformers + git clone https://github.com/OlivierDehaene/transformers.git + cd transformers && git checkout $(transformers_commit) cd transformers && python setup.py install +install-flash-attention: + # Install specific version of flash attention + pip install packaging + pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true + rm -rf flash-attention || true + git clone https://github.com/HazyResearch/flash-attention.git + cd flash-attention && git checkout $(flash_att_commit) + cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install + install-torch: # Install specific version of torch pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 3e2f5c66..2f637ae1 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1,5 +1,7 @@ +import os import torch +from loguru import logger from transformers import AutoConfig from typing import Optional @@ -12,6 +14,14 @@ from text_generation_server.models.santacoder import SantaCoder from text_generation_server.models.gpt_neox import GPTNeoxSharded from text_generation_server.models.t5 import T5Sharded +try: + from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded + FLASH_NEOX = torch.cuda.is_available() and int(os.environ.get("FLASH_NEOX", 0)) == 1 +except ImportError: + if int(os.environ.get("FLASH_NEOX", 0)) == 1: + logger.exception("Could not import FlashNeoX") + FLASH_NEOX = False + __all__ = [ "Model", "BLOOM", @@ -26,6 +36,10 @@ __all__ = [ "get_model", ] +if FLASH_NEOX: + __all__.append(FlashNeoX) + __all__.append(FlashNeoXSharded) + # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. torch.backends.cuda.matmul.allow_tf32 = True @@ -59,9 +73,11 @@ def get_model( if config.model_type == "gpt_neox": if sharded: - return GPTNeoxSharded(model_id, revision, quantize=quantize) + neox_cls = FlashNeoXSharded if FLASH_NEOX else GPTNeoxSharded + return neox_cls(model_id, revision, quantize=quantize) else: - return CausalLM(model_id, revision, quantize=quantize) + neox_cls = FlashNeoX if FLASH_NEOX else CausalLM + return neox_cls(model_id, revision, quantize=quantize) if config.model_type == "t5": if sharded: diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 88ea6c75..c2ad0587 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -64,7 +64,6 @@ class CausalLMBatch(Batch): inputs = [] next_token_choosers = [] stopping_criterias = [] - input_lengths = [] # Parse batch padding_right_offset = 0 diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py new file mode 100644 index 00000000..7be4708b --- /dev/null +++ b/server/text_generation_server/models/flash_neox.py @@ -0,0 +1,601 @@ +import torch +import torch.distributed + +from accelerate import init_empty_weights +from dataclasses import dataclass +from opentelemetry import trace +from safetensors import safe_open +from transformers import AutoTokenizer, PreTrainedTokenizerBase, AutoConfig +from typing import Optional, Tuple, List, Type, Union + +from text_generation_server.models import Model +from text_generation_server.models.flash_neox_modeling import ( + FlashGPTNeoXForCausalLM, + TensorParallelEmbedding, + TensorParallelRowLinear, + TensorParallelColumnLinear, +) +from text_generation_server.models.types import ( + Batch, + PrefillTokens, + Generation, + GeneratedText, +) +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils import ( + NextTokenChooser, + StoppingCriteria, + Sampling, + initialize_torch_distributed, + weight_files, +) + +tracer = trace.get_tracer(__name__) + + +@dataclass +class FlashNeoXBatch(Batch): + batch_id: int + requests: List[generate_pb2.Request] + + # Decoder values + input_ids: torch.Tensor + position_ids: torch.Tensor + # cumulative sequence lengths + cu_seqlens: torch.Tensor + max_seqlen: int + past_key_values: Optional[torch.Tensor] + + # All tokens + all_input_ids: List[List[int]] + + # Lengths of all generations present in the batch + input_lengths: List[int] + + # Generation helpers + next_token_choosers: List[NextTokenChooser] + stopping_criterias: List[StoppingCriteria] + + def to_pb(self) -> generate_pb2.Batch: + return generate_pb2.Batch( + id=self.batch_id, requests=self.requests, size=len(self) + ) + + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + device: torch.device, + ) -> "CausalLMBatch": + input_ids = [] + position_ids = [] + cu_seqlens = [0] + max_seqlen = 0 + + input_lengths = [] + all_input_ids = [] + + next_token_choosers = [] + stopping_criterias = [] + + # Cumulative length + cumulative_length = 0 + + # Parse batch + for r in pb.requests: + tokenized_input = tokenizer(r.inputs, return_tensors="pt")[ + "input_ids" + ].squeeze(0) + input_ids.append(tokenized_input) + all_input_ids.append(tokenized_input.tolist()) + + input_length = len(tokenized_input) + max_seqlen = max(max_seqlen, input_length) + input_lengths.append(input_length) + + # Position ids + position_ids.append(torch.arange(0, input_length, dtype=torch.int32)) + + # Add cumulative lengths of all previous inputs + cu_seqlens.append(cumulative_length + input_length) + + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + stopping_criterias.append( + StoppingCriteria.from_pb(r.stopping_parameters, tokenizer) + ) + + # Update + cumulative_length += input_length + + input_ids = torch.concat(input_ids).unsqueeze(1) + position_ids = torch.concat(position_ids) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32) + + return cls( + batch_id=pb.id, + requests=pb.requests, + input_ids=input_ids, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=None, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + ) + + @classmethod + @tracer.start_as_current_span("concatenate") + def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": + # Batch attributes + requests = [] + input_lengths = [] + all_input_ids = [] + next_token_choosers = [] + stopping_criterias = [] + + # Batch tensors + input_ids = [] + position_ids = [] + cu_seqlens = [torch.tensor([0], dtype=torch.int32)] + max_seqlen = 0 + past_key_values = [] + + # Cumulative length + cumulative_length = torch.tensor(0) + + for i, batch in enumerate(batches): + requests.extend(batch.requests) + input_lengths.extend(batch.input_lengths) + all_input_ids.extend(batch.all_input_ids) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) + + # Add cumulative lengths of all previous inputs + cu_seqlens.append(batch.cu_seqlens[1:] + cumulative_length) + + input_ids.append(batch.input_ids) + position_ids.append(batch.position_ids) + past_key_values.append(batch.past_key_values) + + max_seqlen = max(max_seqlen, batch.max_seqlen) + + # Update + cumulative_length += batch.cu_seqlens[-1] + + input_ids = torch.concat(input_ids) + position_ids = torch.concat(position_ids) + # Concat on dim=1 as first dim represents the model layers + past_key_values = torch.concat(past_key_values, dim=1) + cu_seqlens = torch.concat(cu_seqlens) + + return FlashNeoXBatch( + batch_id=batches[0].batch_id, + requests=requests, + input_ids=input_ids, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=past_key_values, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + ) + + def __len__(self): + return len(self.requests) + + +class FlashNeoX(Model): + def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False): + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + else: + raise NotImplementedError("FlashNeoX is only available on GPU") + + if quantize: + raise NotImplementedError("FlashNeoX does not support quantization") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, revision=revision, padding_side="left" + ) + self.model = ( + FlashGPTNeoXForCausalLM.from_pretrained( + model_id, + revision=revision, + torch_dtype=dtype, + ) + .eval() + .cuda() + ) + tokenizer.pad_token_id = ( + self.model.config.pad_token_id + if self.model.config.pad_token_id is not None + else self.model.config.eos_token_id + ) + + super(FlashNeoX, self).__init__( + tokenizer=tokenizer, + device=device, + ) + + @property + def batch_type(self) -> Type[FlashNeoXBatch]: + return FlashNeoXBatch + + def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: + return self.tokenizer.decode( + generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlens: torch.Tensor, + max_s: int, + past_key_values: Optional = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Model Forward + return self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_s=max_s, + past_key_values=past_key_values, + ) + + @tracer.start_as_current_span("generate_token") + def generate_token( + self, batch: FlashNeoXBatch + ) -> Tuple[List[Generation], Optional[FlashNeoXBatch]]: + # Better to send to device here to avoid device issues in concatenate + position_ids = batch.position_ids.to(self.device, non_blocking=True) + cu_seqlens = batch.cu_seqlens.to(self.device, non_blocking=True) + input_ids = batch.input_ids.squeeze(1).to(self.device) + + out, present = self.forward( + input_ids, + position_ids, + cu_seqlens, + batch.max_seqlen, + batch.past_key_values, + ) + + # List of indices to cache + next_batch_keep_indices = [] + + # New values for next forward + next_batch_input_ids = [] + next_batch_position_ids = [] + next_batch_cu_seqlens = [0] + next_batch_max_seqlen = 0 + next_batch_past_key_values = [] + next_batch_input_lengths = [] + next_batch_all_input_ids = [] + + # Cumulative length + cumulative_length = 0 + + # Results + generations: List[Generation] = [] + + # Zipped iterator + iterator = zip( + batch.requests, + batch.input_lengths, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + ) + + # For each member of the batch + for i, ( + request, + input_length, + next_token_chooser, + stopping_criteria, + all_input_ids, + ) in enumerate(iterator): + # Indexing metadata + start_index = cumulative_length + end_index = cumulative_length + input_length + + if batch.past_key_values is None: + # Prefill mode + # out is of shape [cumulative_sequence_lengths, vocab_size] + logits = out[start_index:end_index] + else: + # Decode mode + # out is of shape [batch_size, vocab_size] + logits = out[i].unsqueeze(0) + + # Select next token + next_token_id, logprobs = next_token_chooser(all_input_ids, logits) + # Copy to cpu to avoid other copies when indexing and calling .item() + next_token_id = next_token_id.to("cpu", non_blocking=True) + logprobs = logprobs.to("cpu") + + next_token_id_squeezed = next_token_id.squeeze() + next_token_id_item = next_token_id_squeezed.item() + + # Append next token to all tokens + all_input_ids.append(next_token_id_item) + new_input_length = input_length + 1 + + # Generated token + next_token_logprob = logprobs[-1, next_token_id] + next_token_text = self.decode_token( + next_token_id_item, + ) + + # Evaluate stopping criteria + stop, reason = stopping_criteria( + next_token_id_item, + next_token_text, + ) + + if stop: + # Decode generated tokens + output_text = self.decode( + all_input_ids[-stopping_criteria.current_tokens :] + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = next_token_chooser.choice.seed + else: + seed = None + + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed + ) + else: + # Keep request in the batch + next_batch_keep_indices.append(i) + generated_text = None + + # Get sequence present + seq_present = present[:, start_index:end_index] + # Pad it for next iter attention + past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1)) + next_batch_past_key_values.append(past) + + next_batch_input_ids.append(next_token_id) + next_batch_position_ids.append(input_length) + # Cumulative sum + next_batch_cu_seqlens.append( + next_batch_cu_seqlens[-1] + new_input_length + ) + next_batch_input_lengths.append(new_input_length) + next_batch_all_input_ids.append(all_input_ids) + next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length) + + # Prefill + if stopping_criteria.current_tokens == 1: + # Remove generated token to only have prefill and add nan for first prompt token + prefill_logprobs = [float("nan")] + logprobs.gather( + 1, torch.tensor(all_input_ids[1:]).unsqueeze(1) + ).squeeze(1)[:-1].tolist() + prefill_token_ids = all_input_ids[:-1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = PrefillTokens( + prefill_token_ids, prefill_logprobs, prefill_texts + ) + else: + prefill_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + next_token_id_item, + next_token_logprob, + next_token_text, + next_token_id_item in self.all_special_ids, + generated_text, + ) + + generations.append(generation) + cumulative_length += input_length + + # We finished all generations in the batch; there is no next batch + if not next_batch_keep_indices: + return generations, None + + # If we finished at least one generation, we need to evict the indices of the generations that finished + # from the values of the next batch + if len(next_batch_keep_indices) != len(batch): + # Apply indices to requests, token_choosers and stopping_criterias that need to be cached + next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices] + next_batch_next_token_choosers = [ + batch.next_token_choosers[i] for i in next_batch_keep_indices + ] + next_batch_stopping_criterias = [ + batch.stopping_criterias[i] for i in next_batch_keep_indices + ] + else: + next_batch_requests = batch.requests + next_batch_next_token_choosers = batch.next_token_choosers + next_batch_stopping_criterias = batch.stopping_criterias + + # Create final next batch tensors + next_batch_position_ids = torch.tensor( + next_batch_position_ids, dtype=torch.int32 + ) + next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32) + if len(next_batch_keep_indices) > 1: + next_batch_input_ids = torch.concat(next_batch_input_ids) + next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1) + else: + next_batch_input_ids = next_batch_input_ids[0] + next_batch_past_key_values = next_batch_past_key_values[0] + + next_batch = FlashNeoXBatch( + batch_id=batch.batch_id, + requests=next_batch_requests, + input_ids=next_batch_input_ids, + position_ids=next_batch_position_ids, + cu_seqlens=next_batch_cu_seqlens, + max_seqlen=next_batch_max_seqlen, + past_key_values=next_batch_past_key_values, + input_lengths=next_batch_input_lengths, + all_input_ids=next_batch_all_input_ids, + next_token_choosers=next_batch_next_token_choosers, + stopping_criterias=next_batch_stopping_criterias, + ) + return generations, next_batch + + +class FlashNeoXSharded(FlashNeoX): + def __init__( + self, model_id: str, revision: Optional[str] = None, quantize: bool = False + ): + self.process_group, self.rank, self.world_size = initialize_torch_distributed() + self.master = self.rank == 0 + if torch.cuda.is_available(): + device = torch.device(f"cuda:{self.rank}") + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + else: + raise NotImplementedError("FlashNeoX is only available on GPU") + + if quantize: + raise NotImplementedError("FlashNeoX does not support quantization") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, revision=revision, padding_side="left" + ) + + config = AutoConfig.from_pretrained( + model_id, revision=revision, tp_parallel=True + ) + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + + with init_empty_weights(): + model = FlashGPTNeoXForCausalLM(config) + + torch.distributed.barrier(group=self.process_group) + self.load_weights( + model, + filenames, + quantize=quantize, + device=device, + rank=self.rank, + world_size=self.world_size, + ) + self.model = model.eval().to(dtype) + torch.distributed.barrier(group=self.process_group) + super(FlashNeoX, self).__init__( + tokenizer=tokenizer, + device=device, + ) + + @staticmethod + def load_weights( + model, + filenames: List[str], + quantize: bool, + device: torch.device, + rank: int, + world_size: int, + ): + parameters = dict(model.named_parameters()) + for file in filenames: + with safe_open( + file, framework="pt", device=str(device) if not quantize else "cpu" + ) as f: + for name in f.keys(): + module_name, param_name = name.rsplit(".", 1) + module = model.get_submodule(module_name) + + current_parameter_tensor = parameters.get(name, None) + + slice_ = f.get_slice(name) + + if isinstance(module, TensorParallelColumnLinear): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif isinstance(module, TensorParallelRowLinear): + if param_name == "weight": + size = slice_.get_shape()[1] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[:, start:stop] + else: + tensor = slice_[:] + # XXX: Hack for Rowlinear to add the bias only once. + if rank != 0: + tensor = torch.zeros_like(tensor) + elif isinstance(module, TensorParallelEmbedding): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings: + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + else: + try: + tensor = slice_[:] + except: + tensor = f.get_tensor(name) + + if ( + current_parameter_tensor is not None + and current_parameter_tensor.shape != tensor.shape + ): + raise ValueError( + f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" + ) + + tensor = tensor.contiguous() + + if current_parameter_tensor is not None: + module._parameters[param_name] = tensor + else: + module._buffers[param_name] = tensor + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlens: torch.Tensor, + max_s: int, + past_key_values: Optional = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.model.gpt_neox.tp_embeddings: + logits, present = self.model.forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_s=max_s, + past_key_values=past_key_values, + ) + + # Logits are sharded, so we need to gather them + world_logits = [torch.empty_like(logits) for _ in range(self.world_size)] + torch.distributed.all_gather(world_logits, logits, group=self.process_group) + world_logits = torch.cat(world_logits, dim=1) + + return world_logits, present + # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard + else: + return super(FlashNeoXSharded, self).forward( + input_ids, position_ids, cu_seqlens, max_s, past_key_values + ) diff --git a/server/text_generation_server/models/flash_neox_modeling.py b/server/text_generation_server/models/flash_neox_modeling.py new file mode 100644 index 00000000..d67ca3c0 --- /dev/null +++ b/server/text_generation_server/models/flash_neox_modeling.py @@ -0,0 +1,637 @@ +import torch +import torch.distributed + +import torch.nn.functional as F + +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_utils import PreTrainedModel +from transformers.models.gpt_neox import GPTNeoXConfig + +# Flash attention imports +import rotary_emb +import flash_attn_cuda +import dropout_layer_norm + +from flash_attn.layers.rotary import RotaryEmbedding + + +class TensorParallelColumnLinear(nn.Linear): + def __init__( + self, + in_features, + out_features, + process_group: torch.distributed.ProcessGroup, + bias=True, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_world_size = process_group.size() + assert out_features % self.tp_world_size == 0 + out_features = out_features // self.tp_world_size + + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + + @staticmethod + def linear(input, weight, bias): + return F.linear(input, weight, bias) + + def forward(self, input): + return self.linear(input, self.weight, self.bias) + + +class TensorParallelRowLinear(nn.Linear): + def __init__( + self, + in_features, + out_features, + process_group: torch.distributed.ProcessGroup, + bias=True, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_world_size = process_group.size() + assert in_features % self.tp_world_size == 0 + in_features = in_features // self.tp_world_size + + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + + @staticmethod + def linear(input, weight, bias): + return F.linear(input, weight, bias) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + out = self.linear(input, self.weight, self.bias) + torch.distributed.all_reduce(out, group=self.process_group) + + return out + + +class TensorParallelEmbedding(nn.Embedding): + def __init__( + self, + num_embeddings, + embedding_dim, + process_group: torch.distributed.ProcessGroup, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + device=None, + dtype=None, + ): + self.process_group = process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + + self.original_num_embeddings = num_embeddings + + assert num_embeddings % self.tp_world_size == 0 + block_size = num_embeddings // self.tp_world_size + # inputs in `[min_id, max_id[` are handled by `self` to get embeddings + self.min_id = self.tp_rank * block_size + self.max_id = (self.tp_rank + 1) * block_size + + super().__init__( + block_size, + embedding_dim, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + _weight=_weight, + device=device, + dtype=dtype, + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # Sanity check + if torch.any( + torch.logical_or(0 > input, input >= self.original_num_embeddings) + ): + raise IndexError( + f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}" + ) + + # `0` if input is in the correct interval, else `1` + input_mask = torch.logical_or(self.min_id > input, input >= self.max_id) + # translate for [0, self.max_id - self.min_id[ + input = input - self.min_id + # default all out of bounds values to `0` + input[input_mask] = 0 + out = super().forward(input) + out[input_mask] = 0.0 + torch.distributed.all_reduce(out, group=self.process_group) + return out + + +class PositionRotaryEmbedding(RotaryEmbedding): + def _update_cos_sin_cache(self, dtype, device, seqlen): + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + # Don't do einsum, it converts fp32 to fp16 + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + + def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype): + """ + Return cos and sin for the asked position ids + """ + + self._update_cos_sin_cache(dtype, position_ids.device, max_s) + + cos = torch.index_select(self._cos_cached, 0, position_ids) + sin = torch.index_select(self._sin_cached, 0, position_ids) + return cos.unsqueeze(1), sin.unsqueeze(1) + + def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + rotary_dim = cos.shape[-1] + q1 = qkv[:, 0, :, :rotary_dim] + q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim] + k1 = qkv[:, 1, :, :rotary_dim] + k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim] + + rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) + rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) + return qkv + + +class FlashNeoxAttention(torch.nn.Module): + def __init__( + self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None + ): + super().__init__() + self.num_heads = num_heads + self.hidden_size = hidden_size + self.head_size = hidden_size // num_heads + + rotary_ndims = int(self.head_size * rotary_pct) + self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base) + self.softmax_scale = self.head_size ** (-0.5) + + if process_group is None: + self.query_key_value = nn.Linear(hidden_size, 3 * hidden_size) + self.dense = nn.Linear(hidden_size, hidden_size) + else: + self.num_heads = self.num_heads // process_group.size() + self.query_key_value = TensorParallelColumnLinear( + hidden_size, + 3 * hidden_size, + process_group=process_group, + ) + self.dense = TensorParallelRowLinear( + hidden_size, + hidden_size, + process_group=process_group, + ) + self.swap_dims = True + + # TODO: remove and swap dims when loading weights + def _swap_dims(self): + """Swap dims for the first inference to avoid an additional permute""" + self.query_key_value.weight = torch.nn.Parameter( + self.query_key_value.weight.view( + self.num_heads, 3, self.head_size, self.hidden_size + ) + .permute(1, 0, 2, 3) + .reshape(-1, self.hidden_size) + ) + self.query_key_value.bias = torch.nn.Parameter( + self.query_key_value.bias.view(self.num_heads, 3, self.head_size) + .permute(1, 0, 2) + .reshape(-1) + ) + self.swap_dims = False + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ): + if self.swap_dims: + self._swap_dims() + + qkv = self.query_key_value(hidden_states) + qkv = qkv.view(-1, 3, self.num_heads, self.head_size) + qkv_rot = self.rotary_emb(qkv, cos, sin) + + # Prefill + if layer_past_present_indices is None: + # Copy to layer past + layer_past[...] = qkv_rot[:, 1:] + + # output + attn_output = torch.empty_like(qkv[:, 0]) + # flash attention + flash_attn_cuda.fwd( + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + attn_output, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + self.softmax_scale, + False, + True, + False, + 0, + None, + ) + # Decode + else: + query = qkv_rot[:, 0] + # Add present to the layer_past tensor at the correct indices + layer_past[layer_past_present_indices] = qkv_rot[:, 1:] + + # output + attn_output = torch.empty_like(query) + # flash attention + flash_attn_cuda.fwd( + query, + layer_past[:, 0], + layer_past[:, 1], + attn_output, + cu_seqlens_q, + cu_seqlens, + 1, + max_s, + 0.0, + self.softmax_scale, + False, + False, + False, + 0, + None, + ) + + return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) + + +class FlashMLP(nn.Module): + def __init__(self, act, hidden_size, intermediate_size, process_group=None): + super().__init__() + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu(x, approximate="tanh") + ) + + if process_group is None: + self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size) + self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size) + else: + self.dense_h_to_4h = TensorParallelColumnLinear( + hidden_size, + intermediate_size, + process_group=process_group, + ) + self.dense_4h_to_h = TensorParallelRowLinear( + intermediate_size, + hidden_size, + process_group=process_group, + ) + self.heuristic = "auto" + self.process_group = process_group + + def forward(self, hidden_states): + hidden_states = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dense_4h_to_h(hidden_states) + return hidden_states + + +class FlashNeoXLayer(nn.Module): + def __init__( + self, + num_heads, + act, + hidden_size, + intermediate_size, + rotary_pct, + rotary_emb_base, + layer_norm_eps, + use_parallel_residual, + process_group=None, + ): + super().__init__() + self.use_parallel_residual = use_parallel_residual + self.input_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.attention = FlashNeoxAttention( + num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group + ) + self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group) + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ): + if self.use_parallel_residual: + # faster input layer norm + ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + None, + self.input_layernorm.weight, + self.input_layernorm.bias, + None, + None, + None, + None, + 0.0, + self.input_layernorm.eps, + 1.0, + 0, + None, + False, + False, + ) + + attn_output = self.attention( + ln1_hidden_states, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ) + + # faster post attention layer norm + ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + None, + self.post_attention_layernorm.weight, + self.post_attention_layernorm.bias, + None, + None, + None, + None, + 0.0, + self.post_attention_layernorm.eps, + 1.0, + 0, + None, + False, + False, + ) + + mlp_output = self.mlp(ln2_hidden_states) + return mlp_output + attn_output + hidden_states, None + else: + # faster input layer norm + hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.input_layernorm.weight, + self.input_layernorm.bias, + None, + None, + None, + None, + 0.0, + self.input_layernorm.eps, + 1.0, + 0, + None, + False, + False, + ) + + hidden_states = self.attention( + hidden_states, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ) + + # faster post attention layer norm + hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.post_attention_layernorm.weight, + self.post_attention_layernorm.bias, + None, + None, + None, + None, + 0.0, + self.post_attention_layernorm.eps, + 1.0, + 0, + None, + False, + False, + ) + + mlp_output = self.mlp(hidden_states) + + return mlp_output, residual + + +class FlashGPTNeoXPreTrainedModel(PreTrainedModel): + config_class = GPTNeoXConfig + base_model_prefix = "gpt_neox" + supports_gradient_checkpointing = False + _no_split_modules = None + + +class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): + def __init__(self, config, process_group=None): + super().__init__(config) + self.config = config + + self.tp_embeddings = False + if process_group is not None: + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + if config.vocab_size % self.tp_world_size == 0: + self.tp_embeddings = True + + if self.tp_embeddings: + self.embed_in = TensorParallelEmbedding( + config.vocab_size, config.hidden_size, process_group=process_group + ) + else: + self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size) + + self.layers = nn.ModuleList( + [ + FlashNeoXLayer( + config.num_attention_heads, + config.hidden_act, + config.hidden_size, + config.intermediate_size, + config.rotary_pct, + config.rotary_emb_base, + config.layer_norm_eps, + config.use_parallel_residual, + process_group, + ) + for _ in range(config.num_hidden_layers) + ] + ) + self.final_layer_norm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].attention.head_size + self.num_heads = self.layers[0].attention.num_heads + + def forward( + self, + input_ids, + position_ids, + cu_seqlens, + max_s, + past_key_values=None, + ): + hidden_states = self.embed_in(input_ids) + + # Prefill + if past_key_values is None: + # Create past tensor + past_key_values = hidden_states.new_empty( + ( + len(self.layers), + len(hidden_states), + 2, + self.num_heads, + self.head_size, + ) + ) + layer_past_present_indices = None + cu_seqlens_q = None + # Decode + else: + # Create indices from cumulative sequence lengths + layer_past_present_indices = cu_seqlens[1:] - 1 + cu_seqlens_q = torch.arange( + len(cu_seqlens), dtype=torch.int32, device=hidden_states.device + ) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlens, + max_s, + past_key_values[i], + layer_past_present_indices, + cu_seqlens_q, + ) + + # Faster final layer norm + hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.final_layer_norm.weight, + self.final_layer_norm.bias, + None, + None, + None, + None, + 0.0, + self.final_layer_norm.eps, + 1.0, + 0, + None, + False, + False, + ) + + return hidden_states, past_key_values + + +class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + if config.tp_parallel: + process_group = torch.distributed.distributed_c10d._get_default_group() + else: + process_group = None + + self.gpt_neox = FlashGPTNeoXModel(config, process_group) + + if self.gpt_neox.tp_embeddings: + self.embed_out = nn.Linear( + config.hidden_size, + config.vocab_size // process_group.size(), + bias=False, + ) + else: + self.embed_out = nn.Linear( + config.hidden_size, config.vocab_size, bias=False + ) + + def forward( + self, + input_ids, + position_ids, + cu_seqlens, + max_s, + past_key_values=None, + ): + hidden_states, present = self.gpt_neox( + input_ids, position_ids, cu_seqlens, max_s, past_key_values + ) + return self.embed_out(hidden_states), present diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index c7594644..597fbe7c 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -24,7 +24,7 @@ class Sampling: self.seed = seed def __call__(self, logits): - probs = torch.nn.functional.softmax(logits) + probs = torch.nn.functional.softmax(logits, -1) next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator) return next_tokens diff --git a/server/text_generation_server/utils/watermark.py b/server/text_generation_server/utils/watermark.py index 8e90a59c..1850561d 100644 --- a/server/text_generation_server/utils/watermark.py +++ b/server/text_generation_server/utils/watermark.py @@ -17,6 +17,7 @@ import os import torch from transformers import LogitsProcessor +from typing import List, Union GAMMA = os.getenv("WATERMARK_GAMMA", 0.5) DELTA = os.getenv("WATERMARK_DELTA", 2.0) @@ -36,23 +37,32 @@ class WatermarkLogitsProcessor(LogitsProcessor): self.rng = torch.Generator(device=device) self.hash_key = hash_key - def _seed_rng(self, input_ids: torch.LongTensor) -> None: - assert ( - input_ids.shape[-1] >= 1 - ), "requires at least a 1 token prefix sequence to seed rng" - prev_token = input_ids[-1].item() + def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]): + if isinstance(input_ids, list): + assert ( + len(input_ids) >= 1 + ), "requires at least a 1 token prefix sequence to seed rng" + prev_token = input_ids[-1] + else: + input_ids = input_ids[0] + assert len(input_ids) == 1 + assert ( + input_ids.shape[-1] >= 1 + ), "requires at least a 1 token prefix sequence to seed rng" + prev_token = input_ids[-1].item() self.rng.manual_seed(self.hash_key * prev_token) def _get_greenlist_ids( - self, input_ids: torch.LongTensor, max_value: int - ) -> list[int]: + self, + input_ids: Union[List[int], torch.LongTensor], + max_value: int, + device: torch.device, + ) -> List[int]: # seed the rng using the previous tokens/prefix self._seed_rng(input_ids) greenlist_size = int(max_value * self.gamma) - vocab_permutation = torch.randperm( - max_value, device=input_ids.device, generator=self.rng - ) + vocab_permutation = torch.randperm(max_value, device=device, generator=self.rng) greenlist_ids = vocab_permutation[:greenlist_size] return greenlist_ids @@ -73,10 +83,11 @@ class WatermarkLogitsProcessor(LogitsProcessor): return scores def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor + self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor ) -> torch.FloatTensor: - assert len(input_ids) == 1 - greenlist_ids = self._get_greenlist_ids(input_ids[0], scores.shape[-1]) + greenlist_ids = self._get_greenlist_ids( + input_ids, scores.shape[-1], scores.device + ) green_tokens_mask = self._calc_greenlist_mask( scores=scores, greenlist_token_ids=greenlist_ids )