feat(server): flash neoX (#133)

This commit is contained in:
OlivierDehaene 2023-03-24 14:02:14 +01:00 committed by GitHub
parent 23e1028822
commit 05e9a796cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1307 additions and 25 deletions

View File

@ -20,6 +20,10 @@ on:
branches: branches:
- 'main' - 'main'
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs: jobs:
build-and-push-image: build-and-push-image:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@ -11,6 +11,10 @@ on:
- "Cargo.lock" - "Cargo.lock"
- "rust-toolchain.toml" - "rust-toolchain.toml"
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs: jobs:
run_tests: run_tests:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04

View File

@ -43,7 +43,7 @@ ENV LANG=C.UTF-8 \
CONDA_DEFAULT_ENV=text-generation \ CONDA_DEFAULT_ENV=text-generation \
PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin PATH=$PATH:/opt/miniconda/envs/text-generation/bin:/opt/miniconda/bin:/usr/local/cuda/bin
RUN apt-get update && apt-get install -y unzip curl libssl-dev && rm -rf /var/lib/apt/lists/* RUN apt-get update && apt-get install -y git curl libssl-dev && rm -rf /var/lib/apt/lists/*
RUN cd ~ && \ RUN cd ~ && \
curl -L -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ curl -L -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
@ -53,10 +53,13 @@ RUN cd ~ && \
WORKDIR /usr/src WORKDIR /usr/src
# Install torch
RUN pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
COPY server/Makefile server/Makefile COPY server/Makefile server/Makefile
# Install specific version of torch # Install specific version of flash attention
RUN cd server && make install-torch RUN cd server && make install-flash-attention
# Install specific version of transformers # Install specific version of transformers
RUN cd server && BUILD_EXTENSIONS="True" make install-transformers RUN cd server && BUILD_EXTENSIONS="True" make install-transformers

View File

@ -1,4 +1,5 @@
transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef transformers_commit := 2b57aa18da658e7d2f42ef6bd5b56751af582fef
flash_att_commit := 4d87e4d875077ad9efd25030efa4ab0ba92c19e1
gen-server: gen-server:
# Compile protos # Compile protos
@ -12,13 +13,19 @@ install-transformers:
# Install specific version of transformers with custom cuda kernels # Install specific version of transformers with custom cuda kernels
pip uninstall transformers -y || true pip uninstall transformers -y || true
rm -rf transformers || true rm -rf transformers || true
rm -rf transformers-$(transformers_commit) || true git clone https://github.com/OlivierDehaene/transformers.git
curl -L -O https://github.com/OlivierDehaene/transformers/archive/$(transformers_commit).zip cd transformers && git checkout $(transformers_commit)
unzip $(transformers_commit).zip
rm $(transformers_commit).zip
mv transformers-$(transformers_commit) transformers
cd transformers && python setup.py install cd transformers && python setup.py install
install-flash-attention:
# Install specific version of flash attention
pip install packaging
pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true
rm -rf flash-attention || true
git clone https://github.com/HazyResearch/flash-attention.git
cd flash-attention && git checkout $(flash_att_commit)
cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
install-torch: install-torch:
# Install specific version of torch # Install specific version of torch
pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir

View File

@ -1,5 +1,7 @@
import os
import torch import torch
from loguru import logger
from transformers import AutoConfig from transformers import AutoConfig
from typing import Optional from typing import Optional
@ -12,6 +14,14 @@ from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.gpt_neox import GPTNeoxSharded from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.t5 import T5Sharded from text_generation_server.models.t5 import T5Sharded
try:
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
FLASH_NEOX = torch.cuda.is_available() and int(os.environ.get("FLASH_NEOX", 0)) == 1
except ImportError:
if int(os.environ.get("FLASH_NEOX", 0)) == 1:
logger.exception("Could not import FlashNeoX")
FLASH_NEOX = False
__all__ = [ __all__ = [
"Model", "Model",
"BLOOM", "BLOOM",
@ -26,6 +36,10 @@ __all__ = [
"get_model", "get_model",
] ]
if FLASH_NEOX:
__all__.append(FlashNeoX)
__all__.append(FlashNeoXSharded)
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later. # in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
@ -59,9 +73,11 @@ def get_model(
if config.model_type == "gpt_neox": if config.model_type == "gpt_neox":
if sharded: if sharded:
return GPTNeoxSharded(model_id, revision, quantize=quantize) neox_cls = FlashNeoXSharded if FLASH_NEOX else GPTNeoxSharded
return neox_cls(model_id, revision, quantize=quantize)
else: else:
return CausalLM(model_id, revision, quantize=quantize) neox_cls = FlashNeoX if FLASH_NEOX else CausalLM
return neox_cls(model_id, revision, quantize=quantize)
if config.model_type == "t5": if config.model_type == "t5":
if sharded: if sharded:

View File

@ -64,7 +64,6 @@ class CausalLMBatch(Batch):
inputs = [] inputs = []
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
input_lengths = []
# Parse batch # Parse batch
padding_right_offset = 0 padding_right_offset = 0

View File

@ -0,0 +1,601 @@
import torch
import torch.distributed
from accelerate import init_empty_weights
from dataclasses import dataclass
from opentelemetry import trace
from safetensors import safe_open
from transformers import AutoTokenizer, PreTrainedTokenizerBase, AutoConfig
from typing import Optional, Tuple, List, Type, Union
from text_generation_server.models import Model
from text_generation_server.models.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelColumnLinear,
)
from text_generation_server.models.types import (
Batch,
PrefillTokens,
Generation,
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
Sampling,
initialize_torch_distributed,
weight_files,
)
tracer = trace.get_tracer(__name__)
@dataclass
class FlashNeoXBatch(Batch):
batch_id: int
requests: List[generate_pb2.Request]
# Decoder values
input_ids: torch.Tensor
position_ids: torch.Tensor
# cumulative sequence lengths
cu_seqlens: torch.Tensor
max_seqlen: int
past_key_values: Optional[torch.Tensor]
# All tokens
all_input_ids: List[List[int]]
# Lengths of all generations present in the batch
input_lengths: List[int]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria]
def to_pb(self) -> generate_pb2.Batch:
return generate_pb2.Batch(
id=self.batch_id, requests=self.requests, size=len(self)
)
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
) -> "CausalLMBatch":
input_ids = []
position_ids = []
cu_seqlens = [0]
max_seqlen = 0
input_lengths = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
# Cumulative length
cumulative_length = 0
# Parse batch
for r in pb.requests:
tokenized_input = tokenizer(r.inputs, return_tensors="pt")[
"input_ids"
].squeeze(0)
input_ids.append(tokenized_input)
all_input_ids.append(tokenized_input.tolist())
input_length = len(tokenized_input)
max_seqlen = max(max_seqlen, input_length)
input_lengths.append(input_length)
# Position ids
position_ids.append(torch.arange(0, input_length, dtype=torch.int32))
# Add cumulative lengths of all previous inputs
cu_seqlens.append(cumulative_length + input_length)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criterias.append(
StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
)
# Update
cumulative_length += input_length
input_ids = torch.concat(input_ids).unsqueeze(1)
position_ids = torch.concat(position_ids)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)
return cls(
batch_id=pb.id,
requests=pb.requests,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
past_key_values=None,
input_lengths=input_lengths,
all_input_ids=all_input_ids,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
)
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
# Batch attributes
requests = []
input_lengths = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
# Batch tensors
input_ids = []
position_ids = []
cu_seqlens = [torch.tensor([0], dtype=torch.int32)]
max_seqlen = 0
past_key_values = []
# Cumulative length
cumulative_length = torch.tensor(0)
for i, batch in enumerate(batches):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
# Add cumulative lengths of all previous inputs
cu_seqlens.append(batch.cu_seqlens[1:] + cumulative_length)
input_ids.append(batch.input_ids)
position_ids.append(batch.position_ids)
past_key_values.append(batch.past_key_values)
max_seqlen = max(max_seqlen, batch.max_seqlen)
# Update
cumulative_length += batch.cu_seqlens[-1]
input_ids = torch.concat(input_ids)
position_ids = torch.concat(position_ids)
# Concat on dim=1 as first dim represents the model layers
past_key_values = torch.concat(past_key_values, dim=1)
cu_seqlens = torch.concat(cu_seqlens)
return FlashNeoXBatch(
batch_id=batches[0].batch_id,
requests=requests,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
past_key_values=past_key_values,
input_lengths=input_lengths,
all_input_ids=all_input_ids,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
)
def __len__(self):
return len(self.requests)
class FlashNeoX(Model):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
raise NotImplementedError("FlashNeoX is only available on GPU")
if quantize:
raise NotImplementedError("FlashNeoX does not support quantization")
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left"
)
self.model = (
FlashGPTNeoXForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
)
.eval()
.cuda()
)
tokenizer.pad_token_id = (
self.model.config.pad_token_id
if self.model.config.pad_token_id is not None
else self.model.config.eos_token_id
)
super(FlashNeoX, self).__init__(
tokenizer=tokenizer,
device=device,
)
@property
def batch_type(self) -> Type[FlashNeoXBatch]:
return FlashNeoXBatch
def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
return self.tokenizer.decode(
generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlens: torch.Tensor,
max_s: int,
past_key_values: Optional = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
return self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
max_s=max_s,
past_key_values=past_key_values,
)
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: FlashNeoXBatch
) -> Tuple[List[Generation], Optional[FlashNeoXBatch]]:
# Better to send to device here to avoid device issues in concatenate
position_ids = batch.position_ids.to(self.device, non_blocking=True)
cu_seqlens = batch.cu_seqlens.to(self.device, non_blocking=True)
input_ids = batch.input_ids.squeeze(1).to(self.device)
out, present = self.forward(
input_ids,
position_ids,
cu_seqlens,
batch.max_seqlen,
batch.past_key_values,
)
# List of indices to cache
next_batch_keep_indices = []
# New values for next forward
next_batch_input_ids = []
next_batch_position_ids = []
next_batch_cu_seqlens = [0]
next_batch_max_seqlen = 0
next_batch_past_key_values = []
next_batch_input_lengths = []
next_batch_all_input_ids = []
# Cumulative length
cumulative_length = 0
# Results
generations: List[Generation] = []
# Zipped iterator
iterator = zip(
batch.requests,
batch.input_lengths,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
# For each member of the batch
for i, (
request,
input_length,
next_token_chooser,
stopping_criteria,
all_input_ids,
) in enumerate(iterator):
# Indexing metadata
start_index = cumulative_length
end_index = cumulative_length + input_length
if batch.past_key_values is None:
# Prefill mode
# out is of shape [cumulative_sequence_lengths, vocab_size]
logits = out[start_index:end_index]
else:
# Decode mode
# out is of shape [batch_size, vocab_size]
logits = out[i].unsqueeze(0)
# Select next token
next_token_id, logprobs = next_token_chooser(all_input_ids, logits)
# Copy to cpu to avoid other copies when indexing and calling .item()
next_token_id = next_token_id.to("cpu", non_blocking=True)
logprobs = logprobs.to("cpu")
next_token_id_squeezed = next_token_id.squeeze()
next_token_id_item = next_token_id_squeezed.item()
# Append next token to all tokens
all_input_ids.append(next_token_id_item)
new_input_length = input_length + 1
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_text = self.decode_token(
next_token_id_item,
)
# Evaluate stopping criteria
stop, reason = stopping_criteria(
next_token_id_item,
next_token_text,
)
if stop:
# Decode generated tokens
output_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens :]
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
)
else:
# Keep request in the batch
next_batch_keep_indices.append(i)
generated_text = None
# Get sequence present
seq_present = present[:, start_index:end_index]
# Pad it for next iter attention
past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
next_batch_past_key_values.append(past)
next_batch_input_ids.append(next_token_id)
next_batch_position_ids.append(input_length)
# Cumulative sum
next_batch_cu_seqlens.append(
next_batch_cu_seqlens[-1] + new_input_length
)
next_batch_input_lengths.append(new_input_length)
next_batch_all_input_ids.append(all_input_ids)
next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)
# Prefill
if stopping_criteria.current_tokens == 1:
# Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + logprobs.gather(
1, torch.tensor(all_input_ids[1:]).unsqueeze(1)
).squeeze(1)[:-1].tolist()
prefill_token_ids = all_input_ids[:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, prefill_logprobs, prefill_texts
)
else:
prefill_tokens = None
generation = Generation(
request.id,
prefill_tokens,
next_token_id_item,
next_token_logprob,
next_token_text,
next_token_id_item in self.all_special_ids,
generated_text,
)
generations.append(generation)
cumulative_length += input_length
# We finished all generations in the batch; there is no next batch
if not next_batch_keep_indices:
return generations, None
# If we finished at least one generation, we need to evict the indices of the generations that finished
# from the values of the next batch
if len(next_batch_keep_indices) != len(batch):
# Apply indices to requests, token_choosers and stopping_criterias that need to be cached
next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
next_batch_next_token_choosers = [
batch.next_token_choosers[i] for i in next_batch_keep_indices
]
next_batch_stopping_criterias = [
batch.stopping_criterias[i] for i in next_batch_keep_indices
]
else:
next_batch_requests = batch.requests
next_batch_next_token_choosers = batch.next_token_choosers
next_batch_stopping_criterias = batch.stopping_criterias
# Create final next batch tensors
next_batch_position_ids = torch.tensor(
next_batch_position_ids, dtype=torch.int32
)
next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32)
if len(next_batch_keep_indices) > 1:
next_batch_input_ids = torch.concat(next_batch_input_ids)
next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1)
else:
next_batch_input_ids = next_batch_input_ids[0]
next_batch_past_key_values = next_batch_past_key_values[0]
next_batch = FlashNeoXBatch(
batch_id=batch.batch_id,
requests=next_batch_requests,
input_ids=next_batch_input_ids,
position_ids=next_batch_position_ids,
cu_seqlens=next_batch_cu_seqlens,
max_seqlen=next_batch_max_seqlen,
past_key_values=next_batch_past_key_values,
input_lengths=next_batch_input_lengths,
all_input_ids=next_batch_all_input_ids,
next_token_choosers=next_batch_next_token_choosers,
stopping_criterias=next_batch_stopping_criterias,
)
return generations, next_batch
class FlashNeoXSharded(FlashNeoX):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
raise NotImplementedError("FlashNeoX is only available on GPU")
if quantize:
raise NotImplementedError("FlashNeoX does not support quantization")
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left"
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, tp_parallel=True
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = FlashGPTNeoXForCausalLM(config)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
rank=self.rank,
world_size=self.world_size,
)
self.model = model.eval().to(dtype)
torch.distributed.barrier(group=self.process_group)
super(FlashNeoX, self).__init__(
tokenizer=tokenizer,
device=device,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: bool,
device: torch.device,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if not quantize else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous()
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlens: torch.Tensor,
max_s: int,
past_key_values: Optional = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.model.gpt_neox.tp_embeddings:
logits, present = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
max_s=max_s,
past_key_values=past_key_values,
)
# Logits are sharded, so we need to gather them
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
# While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
else:
return super(FlashNeoXSharded, self).forward(
input_ids, position_ids, cu_seqlens, max_s, past_key_values
)

View File

@ -0,0 +1,637 @@
import torch
import torch.distributed
import torch.nn.functional as F
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig
# Flash attention imports
import rotary_emb
import flash_attn_cuda
import dropout_layer_norm
from flash_attn.layers.rotary import RotaryEmbedding
class TensorParallelColumnLinear(nn.Linear):
def __init__(
self,
in_features,
out_features,
process_group: torch.distributed.ProcessGroup,
bias=True,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_world_size = process_group.size()
assert out_features % self.tp_world_size == 0
out_features = out_features // self.tp_world_size
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
@staticmethod
def linear(input, weight, bias):
return F.linear(input, weight, bias)
def forward(self, input):
return self.linear(input, self.weight, self.bias)
class TensorParallelRowLinear(nn.Linear):
def __init__(
self,
in_features,
out_features,
process_group: torch.distributed.ProcessGroup,
bias=True,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_world_size = process_group.size()
assert in_features % self.tp_world_size == 0
in_features = in_features // self.tp_world_size
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
@staticmethod
def linear(input, weight, bias):
return F.linear(input, weight, bias)
def forward(self, input: torch.Tensor) -> torch.Tensor:
out = self.linear(input, self.weight, self.bias)
torch.distributed.all_reduce(out, group=self.process_group)
return out
class TensorParallelEmbedding(nn.Embedding):
def __init__(
self,
num_embeddings,
embedding_dim,
process_group: torch.distributed.ProcessGroup,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.original_num_embeddings = num_embeddings
assert num_embeddings % self.tp_world_size == 0
block_size = num_embeddings // self.tp_world_size
# inputs in `[min_id, max_id[` are handled by `self` to get embeddings
self.min_id = self.tp_rank * block_size
self.max_id = (self.tp_rank + 1) * block_size
super().__init__(
block_size,
embedding_dim,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
_weight=_weight,
device=device,
dtype=dtype,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
# Sanity check
if torch.any(
torch.logical_or(0 > input, input >= self.original_num_embeddings)
):
raise IndexError(
f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}"
)
# `0` if input is in the correct interval, else `1`
input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
# translate for [0, self.max_id - self.min_id[
input = input - self.min_id
# default all out of bounds values to `0`
input[input_mask] = 0
out = super().forward(input)
out[input_mask] = 0.0
torch.distributed.all_reduce(out, group=self.process_group)
return out
class PositionRotaryEmbedding(RotaryEmbedding):
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
"""
Return cos and sin for the asked position ids
"""
self._update_cos_sin_cache(dtype, position_ids.device, max_s)
cos = torch.index_select(self._cos_cached, 0, position_ids)
sin = torch.index_select(self._sin_cached, 0, position_ids)
return cos.unsqueeze(1), sin.unsqueeze(1)
def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
rotary_dim = cos.shape[-1]
q1 = qkv[:, 0, :, :rotary_dim]
q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
k1 = qkv[:, 1, :, :rotary_dim]
k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
return qkv
class FlashNeoxAttention(torch.nn.Module):
def __init__(
self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
):
super().__init__()
self.num_heads = num_heads
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
rotary_ndims = int(self.head_size * rotary_pct)
self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
self.softmax_scale = self.head_size ** (-0.5)
if process_group is None:
self.query_key_value = nn.Linear(hidden_size, 3 * hidden_size)
self.dense = nn.Linear(hidden_size, hidden_size)
else:
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = TensorParallelColumnLinear(
hidden_size,
3 * hidden_size,
process_group=process_group,
)
self.dense = TensorParallelRowLinear(
hidden_size,
hidden_size,
process_group=process_group,
)
self.swap_dims = True
# TODO: remove and swap dims when loading weights
def _swap_dims(self):
"""Swap dims for the first inference to avoid an additional permute"""
self.query_key_value.weight = torch.nn.Parameter(
self.query_key_value.weight.view(
self.num_heads, 3, self.head_size, self.hidden_size
)
.permute(1, 0, 2, 3)
.reshape(-1, self.hidden_size)
)
self.query_key_value.bias = torch.nn.Parameter(
self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
.permute(1, 0, 2)
.reshape(-1)
)
self.swap_dims = False
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
):
if self.swap_dims:
self._swap_dims()
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
qkv_rot = self.rotary_emb(qkv, cos, sin)
# Prefill
if layer_past_present_indices is None:
# Copy to layer past
layer_past[...] = qkv_rot[:, 1:]
# output
attn_output = torch.empty_like(qkv[:, 0])
# flash attention
flash_attn_cuda.fwd(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attn_output,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:
query = qkv_rot[:, 0]
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = qkv_rot[:, 1:]
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
layer_past[:, 0],
layer_past[:, 1],
attn_output,
cu_seqlens_q,
cu_seqlens,
1,
max_s,
0.0,
self.softmax_scale,
False,
False,
False,
0,
None,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
class FlashMLP(nn.Module):
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
super().__init__()
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
)
if process_group is None:
self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size)
self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size)
else:
self.dense_h_to_4h = TensorParallelColumnLinear(
hidden_size,
intermediate_size,
process_group=process_group,
)
self.dense_4h_to_h = TensorParallelRowLinear(
intermediate_size,
hidden_size,
process_group=process_group,
)
self.heuristic = "auto"
self.process_group = process_group
def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dense_4h_to_h(hidden_states)
return hidden_states
class FlashNeoXLayer(nn.Module):
def __init__(
self,
num_heads,
act,
hidden_size,
intermediate_size,
rotary_pct,
rotary_emb_base,
layer_norm_eps,
use_parallel_residual,
process_group=None,
):
super().__init__()
self.use_parallel_residual = use_parallel_residual
self.input_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
self.attention = FlashNeoxAttention(
num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group
)
self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
):
if self.use_parallel_residual:
# faster input layer norm
ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
None,
self.input_layernorm.weight,
self.input_layernorm.bias,
None,
None,
None,
None,
0.0,
self.input_layernorm.eps,
1.0,
0,
None,
False,
False,
)
attn_output = self.attention(
ln1_hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
)
# faster post attention layer norm
ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
None,
self.post_attention_layernorm.weight,
self.post_attention_layernorm.bias,
None,
None,
None,
None,
0.0,
self.post_attention_layernorm.eps,
1.0,
0,
None,
False,
False,
)
mlp_output = self.mlp(ln2_hidden_states)
return mlp_output + attn_output + hidden_states, None
else:
# faster input layer norm
hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.input_layernorm.weight,
self.input_layernorm.bias,
None,
None,
None,
None,
0.0,
self.input_layernorm.eps,
1.0,
0,
None,
False,
False,
)
hidden_states = self.attention(
hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
)
# faster post attention layer norm
hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.post_attention_layernorm.weight,
self.post_attention_layernorm.bias,
None,
None,
None,
None,
0.0,
self.post_attention_layernorm.eps,
1.0,
0,
None,
False,
False,
)
mlp_output = self.mlp(hidden_states)
return mlp_output, residual
class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
config_class = GPTNeoXConfig
base_model_prefix = "gpt_neox"
supports_gradient_checkpointing = False
_no_split_modules = None
class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, process_group=None):
super().__init__(config)
self.config = config
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.embed_in = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group
)
else:
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList(
[
FlashNeoXLayer(
config.num_attention_heads,
config.hidden_act,
config.hidden_size,
config.intermediate_size,
config.rotary_pct,
config.rotary_emb_base,
config.layer_norm_eps,
config.use_parallel_residual,
process_group,
)
for _ in range(config.num_hidden_layers)
]
)
self.final_layer_norm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
self.gradient_checkpointing = False
self.head_size = self.layers[0].attention.head_size
self.num_heads = self.layers[0].attention.num_heads
def forward(
self,
input_ids,
position_ids,
cu_seqlens,
max_s,
past_key_values=None,
):
hidden_states = self.embed_in(input_ids)
# Prefill
if past_key_values is None:
# Create past tensor
past_key_values = hidden_states.new_empty(
(
len(self.layers),
len(hidden_states),
2,
self.num_heads,
self.head_size,
)
)
layer_past_present_indices = None
cu_seqlens_q = None
# Decode
else:
# Create indices from cumulative sequence lengths
layer_past_present_indices = cu_seqlens[1:] - 1
cu_seqlens_q = torch.arange(
len(cu_seqlens), dtype=torch.int32, device=hidden_states.device
)
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlens,
max_s,
past_key_values[i],
layer_past_present_indices,
cu_seqlens_q,
)
# Faster final layer norm
hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.final_layer_norm.weight,
self.final_layer_norm.bias,
None,
None,
None,
None,
0.0,
self.final_layer_norm.eps,
1.0,
0,
None,
False,
False,
)
return hidden_states, past_key_values
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
def __init__(self, config):
super().__init__(config)
if config.tp_parallel:
process_group = torch.distributed.distributed_c10d._get_default_group()
else:
process_group = None
self.gpt_neox = FlashGPTNeoXModel(config, process_group)
if self.gpt_neox.tp_embeddings:
self.embed_out = nn.Linear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
)
else:
self.embed_out = nn.Linear(
config.hidden_size, config.vocab_size, bias=False
)
def forward(
self,
input_ids,
position_ids,
cu_seqlens,
max_s,
past_key_values=None,
):
hidden_states, present = self.gpt_neox(
input_ids, position_ids, cu_seqlens, max_s, past_key_values
)
return self.embed_out(hidden_states), present

View File

@ -24,7 +24,7 @@ class Sampling:
self.seed = seed self.seed = seed
def __call__(self, logits): def __call__(self, logits):
probs = torch.nn.functional.softmax(logits) probs = torch.nn.functional.softmax(logits, -1)
next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator) next_tokens = torch.multinomial(probs, num_samples=1, generator=self.generator)
return next_tokens return next_tokens

View File

@ -17,6 +17,7 @@ import os
import torch import torch
from transformers import LogitsProcessor from transformers import LogitsProcessor
from typing import List, Union
GAMMA = os.getenv("WATERMARK_GAMMA", 0.5) GAMMA = os.getenv("WATERMARK_GAMMA", 0.5)
DELTA = os.getenv("WATERMARK_DELTA", 2.0) DELTA = os.getenv("WATERMARK_DELTA", 2.0)
@ -36,7 +37,15 @@ class WatermarkLogitsProcessor(LogitsProcessor):
self.rng = torch.Generator(device=device) self.rng = torch.Generator(device=device)
self.hash_key = hash_key self.hash_key = hash_key
def _seed_rng(self, input_ids: torch.LongTensor) -> None: def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]):
if isinstance(input_ids, list):
assert (
len(input_ids) >= 1
), "requires at least a 1 token prefix sequence to seed rng"
prev_token = input_ids[-1]
else:
input_ids = input_ids[0]
assert len(input_ids) == 1
assert ( assert (
input_ids.shape[-1] >= 1 input_ids.shape[-1] >= 1
), "requires at least a 1 token prefix sequence to seed rng" ), "requires at least a 1 token prefix sequence to seed rng"
@ -44,15 +53,16 @@ class WatermarkLogitsProcessor(LogitsProcessor):
self.rng.manual_seed(self.hash_key * prev_token) self.rng.manual_seed(self.hash_key * prev_token)
def _get_greenlist_ids( def _get_greenlist_ids(
self, input_ids: torch.LongTensor, max_value: int self,
) -> list[int]: input_ids: Union[List[int], torch.LongTensor],
max_value: int,
device: torch.device,
) -> List[int]:
# seed the rng using the previous tokens/prefix # seed the rng using the previous tokens/prefix
self._seed_rng(input_ids) self._seed_rng(input_ids)
greenlist_size = int(max_value * self.gamma) greenlist_size = int(max_value * self.gamma)
vocab_permutation = torch.randperm( vocab_permutation = torch.randperm(max_value, device=device, generator=self.rng)
max_value, device=input_ids.device, generator=self.rng
)
greenlist_ids = vocab_permutation[:greenlist_size] greenlist_ids = vocab_permutation[:greenlist_size]
return greenlist_ids return greenlist_ids
@ -73,10 +83,11 @@ class WatermarkLogitsProcessor(LogitsProcessor):
return scores return scores
def __call__( def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor
) -> torch.FloatTensor: ) -> torch.FloatTensor:
assert len(input_ids) == 1 greenlist_ids = self._get_greenlist_ids(
greenlist_ids = self._get_greenlist_ids(input_ids[0], scores.shape[-1]) input_ids, scores.shape[-1], scores.device
)
green_tokens_mask = self._calc_greenlist_mask( green_tokens_mask = self._calc_greenlist_mask(
scores=scores, greenlist_token_ids=greenlist_ids scores=scores, greenlist_token_ids=greenlist_ids
) )