hf_text-generation-inference/server/text_generation_server/models/flash_neox.py

617 lines
22 KiB
Python
Raw Normal View History

2023-03-24 07:02:14 -06:00
import torch
import torch.distributed
from torch.nn import functional as F
2023-03-24 07:02:14 -06:00
from accelerate import init_empty_weights
from dataclasses import dataclass
from opentelemetry import trace
from safetensors import safe_open
from transformers import AutoTokenizer, PreTrainedTokenizerBase, AutoConfig
from typing import Optional, Tuple, List, Type, Union
from text_generation_server.models import Model
from text_generation_server.models.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelColumnLinear,
)
from text_generation_server.models.types import (
Batch,
PrefillTokens,
Generation,
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
Sampling,
initialize_torch_distributed,
weight_files,
)
# Module-level OpenTelemetry tracer; the `@tracer.start_as_current_span(...)`
# decorators on `concatenate` and `generate_token` in this file record spans with it.
tracer = trace.get_tracer(__name__)
@dataclass
class FlashNeoXBatch(Batch):
    """Batch state for flash-attention GPT-NeoX generation.

    Requests are stored back to back rather than padded to a common length:
    ``cu_seqlens`` holds cumulative sequence lengths, so request ``k`` spans
    ``[cu_seqlens[k], cu_seqlens[k + 1])``. On the prefill batch built by
    ``from_pb`` ``input_ids`` holds every prompt token of every request; the
    batches produced by ``FlashNeoX.generate_token`` hold one new token per
    request in ``input_ids`` while ``cu_seqlens`` still tracks full lengths.
    """

    batch_id: int
    requests: List[generate_pb2.Request]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    # cumulative sequence lengths
    cu_seqlens: torch.Tensor
    max_seqlen: int
    # None until the batch has been through one forward pass
    past_key_values: Optional[torch.Tensor]

    # All tokens
    all_input_ids: List[List[int]]
    # Per-request token tensors, right-padded with room for max_new_tokens
    all_input_ids_tensor: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    def to_pb(self) -> generate_pb2.Batch:
        """Serialize the batch id, its requests and its size to protobuf."""
        return generate_pb2.Batch(
            id=self.batch_id, requests=self.requests, size=len(self)
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "FlashNeoXBatch":
        """Build a prefill batch from a protobuf batch.

        Every request's ``inputs`` is tokenized; the token ids of all requests
        are concatenated into one 1-D ``input_ids`` tensor on ``device`` and
        the request boundaries are recorded in ``cu_seqlens``.

        Args:
            pb: the incoming protobuf batch.
            tokenizer: tokenizer used to encode each request's text.
            device: device on which token tensors are created.

        Returns:
            A new batch with ``past_key_values=None`` (nothing has been run
            through the model yet).
        """
        input_ids = []
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0

        input_lengths = []
        all_input_ids = []
        all_input_ids_tensor = []

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0

        # Parse batch
        for r in pb.requests:
            tokenized_input = tokenizer(r.inputs)["input_ids"]
            input_length = len(tokenized_input)
            max_seqlen = max(max_seqlen, input_length)
            input_lengths.append(input_length)
            all_input_ids.append(tokenized_input)

            tokenized_input = torch.tensor(tokenized_input, device=device)
            input_ids.append(tokenized_input)

            # Position ids restart at 0 for every request
            position_ids.append(torch.arange(0, input_length, dtype=torch.int32))

            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(cumulative_length + input_length)

            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)

            # Reserve one slot per token this request may generate, so
            # generate_token can write new ids into the tensor in place.
            all_input_ids_tensor.append(
                F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens))
            )

            # Update
            cumulative_length += input_length

        input_ids = torch.concat(input_ids)
        position_ids = torch.concat(position_ids)
        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            past_key_values=None,
            input_lengths=input_lengths,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashNeoXBatch"]) -> "FlashNeoXBatch":
        """Merge several batches into a single batch.

        Per-request lists are chained in order, token tensors are
        concatenated, and each batch's ``cu_seqlens`` (minus its leading 0) is
        shifted by the total length of all batches before it. The merged batch
        reuses the id of ``batches[0]``.

        NOTE(review): every batch must already carry ``past_key_values`` (i.e.
        have been through ``generate_token``); ``torch.concat`` would fail on a
        ``None`` entry.
        """
        # Batch attributes
        requests = []
        input_lengths = []
        all_input_ids = []
        all_input_ids_tensor = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        input_ids = []
        position_ids = []
        cu_seqlens = [torch.tensor([0], dtype=torch.int32)]
        max_seqlen = 0
        past_key_values = []

        # Cumulative length
        cumulative_length = torch.tensor(0)

        for batch in batches:
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            all_input_ids.extend(batch.all_input_ids)
            all_input_ids_tensor.extend(batch.all_input_ids_tensor)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(batch.cu_seqlens[1:] + cumulative_length)

            input_ids.append(batch.input_ids)
            position_ids.append(batch.position_ids)
            past_key_values.append(batch.past_key_values)

            max_seqlen = max(max_seqlen, batch.max_seqlen)

            # Update
            cumulative_length += batch.cu_seqlens[-1]

        input_ids = torch.concat(input_ids)
        position_ids = torch.concat(position_ids)
        # Concat on dim=1 as first dim represents the model layers
        past_key_values = torch.concat(past_key_values, dim=1)
        cu_seqlens = torch.concat(cu_seqlens)

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
        )

    def __len__(self):
        """Number of requests in the batch."""
        return len(self.requests)
class FlashNeoX(Model):
    """GPT-NeoX served with flash attention on a single CUDA device."""

    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
        """Load the tokenizer and model for ``model_id`` onto the GPU.

        Raises:
            NotImplementedError: if CUDA is unavailable or ``quantize`` is set.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashNeoX is only available on GPU")

        if quantize:
            raise NotImplementedError("FlashNeoX does not support quantization")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )
        self.model = (
            FlashGPTNeoXForCausalLM.from_pretrained(
                model_id,
                revision=revision,
                torch_dtype=dtype,
            )
            .eval()
            .cuda()
        )
        # Fall back to the EOS token when the config defines no pad token
        tokenizer.pad_token_id = (
            self.model.config.pad_token_id
            if self.model.config.pad_token_id is not None
            else self.model.config.eos_token_id
        )

        super(FlashNeoX, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[FlashNeoXBatch]:
        return FlashNeoXBatch

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        """Decode token ids to text.

        Special tokens are skipped and tokenization spaces are left as-is.
        (The keyword must be spelled ``clean_up_tokenization_spaces``; a
        misspelled keyword is silently ignored by the tokenizer and the
        default cleanup would run instead.)
        """
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_s: int,
        past_key_values: Optional = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the model and return ``(logits, present)``."""
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_s=max_s,
            past_key_values=past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashNeoXBatch
    ) -> Tuple[List[Generation], Optional[FlashNeoXBatch]]:
        """Run one forward step and choose the next token for every request.

        When ``batch.past_key_values is None`` this is the prefill step over all
        prompt tokens; otherwise each request contributes one new token.

        Returns:
            The per-request ``Generation`` results and the batch for the next
            step, or ``None`` when every request in the batch has stopped.
        """
        # Better to send to device here to avoid device issues in concatenate
        position_ids = batch.position_ids.to(self.device, non_blocking=True)
        cu_seqlens = batch.cu_seqlens.to(self.device)

        out, present = self.forward(
            batch.input_ids,
            position_ids,
            cu_seqlens,
            batch.max_seqlen,
            batch.past_key_values,
        )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_ids = []
        next_batch_position_ids = []
        next_batch_cu_seqlens = [0]
        next_batch_max_seqlen = 0
        next_batch_past_key_values = []
        next_batch_input_lengths = []
        next_batch_all_input_ids = []
        next_batch_all_input_ids_tensor = []

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.all_input_ids_tensor,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            all_input_ids_tensor,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if batch.past_key_values is None:
                # Prefill mode
                # out is of shape [cumulative_sequence_lengths, vocab_size]
                logits = out[start_index:end_index]
            else:
                # Decode mode
                # out is of shape [batch_size, vocab_size]
                logits = out[i].unsqueeze(0)

            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids_tensor[None, :input_length], logits
            )
            next_token_id_item = next_token_id.squeeze().item()

            # Append next token to all tokens
            all_input_ids.append(next_token_id_item)
            # Write into the slot reserved by F.pad in FlashNeoXBatch.from_pb
            all_input_ids_tensor[input_length] = next_token_id_item
            new_input_length = input_length + 1

            # Generated token: the last row of logprobs scores the new token
            next_token_logprob = logprobs[-1, next_token_id_item]
            next_token_text = self.decode_token(
                next_token_id_item,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_item,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :]
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                next_batch_keep_indices.append(i)
                generated_text = None

                # Get sequence present
                seq_present = present[:, start_index:end_index]
                # Pad it for next iter attention: appends one zero slot at the
                # end of the fourth-from-last dim (presumably the token dim of
                # `present`, to hold the next token's key/value — confirm
                # against the model's cache layout).
                past = F.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
                next_batch_past_key_values.append(past)

                next_batch_input_ids.append(next_token_id)
                next_batch_position_ids.append(input_length)
                # Cumulative sum
                next_batch_cu_seqlens.append(
                    next_batch_cu_seqlens[-1] + new_input_length
                )
                next_batch_input_lengths.append(new_input_length)
                next_batch_all_input_ids.append(all_input_ids)
                next_batch_all_input_ids_tensor.append(all_input_ids_tensor)
                next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)

            # Prefill
            if stopping_criteria.current_tokens == 1:
                # Per-prompt-token logprobs. Since the last row of `logprobs`
                # scores the token at index input_length, row j scores the
                # token at j + 1; gathering prompt tokens 1..input_length-1
                # therefore yields exactly input_length - 1 values. The
                # generated token is not in this slice, so nothing is trimmed.
                # The first prompt token has no score, hence the leading NaN.
                prefill_logprobs = [float("nan")] + logprobs.gather(
                    1, all_input_ids_tensor[1:input_length].unsqueeze(1)
                ).squeeze(1).tolist()
                prefill_token_ids = all_input_ids[:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_item,
                next_token_logprob,
                next_token_text,
                next_token_id_item in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)
            cumulative_length += input_length

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generations, None

        # If we finished at least one generation, we need to evict the indices
        # of the generations that finished from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to requests, token_choosers and stopping_criterias that need to be cached
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Create final next batch tensors
        next_batch_position_ids = torch.tensor(
            next_batch_position_ids, dtype=torch.int32
        )
        next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32)
        if len(next_batch_keep_indices) > 1:
            next_batch_input_ids = torch.concat(next_batch_input_ids).squeeze(1)
            next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1)
        else:
            next_batch_input_ids = next_batch_input_ids[0].view(1)
            next_batch_past_key_values = next_batch_past_key_values[0]

        next_batch = FlashNeoXBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            position_ids=next_batch_position_ids,
            cu_seqlens=next_batch_cu_seqlens,
            max_seqlen=next_batch_max_seqlen,
            past_key_values=next_batch_past_key_values,
            input_lengths=next_batch_input_lengths,
            all_input_ids=next_batch_all_input_ids,
            all_input_ids_tensor=next_batch_all_input_ids_tensor,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
        )
        return generations, next_batch
class FlashNeoXSharded(FlashNeoX):
    """Tensor-parallel FlashNeoX: each rank loads only its shard of the weights."""

    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        """Initialize the process group and load this rank's weight shards.

        Raises:
            NotImplementedError: if CUDA is unavailable or ``quantize`` is set.
        """
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashNeoX is only available on GPU")

        if quantize:
            raise NotImplementedError("FlashNeoX does not support quantization")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, tp_parallel=True
        )

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        # Build the module tree without allocating weights; load_weights fills them
        with init_empty_weights():
            model = FlashGPTNeoXForCausalLM(config)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            rank=self.rank,
            world_size=self.world_size,
        )
        model.post_load_weights()
        self.model = model.eval().to(dtype)
        torch.distributed.barrier(group=self.process_group)
        # Deliberately skip FlashNeoX.__init__ (which loads the model via
        # from_pretrained) and call the base Model initializer directly.
        super(FlashNeoX, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @staticmethod
    def _shard_bounds(size: int, rank: int, world_size: int) -> Tuple[int, int]:
        """Return the ``[start, stop)`` slice of a length-``size`` dim owned by ``rank``."""
        block_size = size // world_size
        return rank * block_size, (rank + 1) * block_size

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: bool,
        device: torch.device,
        rank: int,
        world_size: int,
    ):
        """Load this rank's slice of every tensor in ``filenames`` into ``model``.

        Column-parallel linears, parallel embeddings and (when
        ``tp_embeddings`` is set) ``embed_out.weight`` are split along dim 0;
        row-parallel linear weights are split along dim 1. Everything else is
        loaded whole.

        Raises:
            ValueError: if a loaded tensor's shape differs from the existing
                parameter's shape.
        """
        parameters = dict(model.named_parameters())
        for file in filenames:
            with safe_open(
                file, framework="pt", device=str(device) if not quantize else "cpu"
            ) as f:
                for name in f.keys():
                    module_name, param_name = name.rsplit(".", 1)
                    module = model.get_submodule(module_name)

                    current_parameter_tensor = parameters.get(name, None)

                    slice_ = f.get_slice(name)

                    if isinstance(module, TensorParallelColumnLinear):
                        start, stop = FlashNeoXSharded._shard_bounds(
                            slice_.get_shape()[0], rank, world_size
                        )
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            start, stop = FlashNeoXSharded._shard_bounds(
                                slice_.get_shape()[1], rank, world_size
                            )
                            tensor = slice_[:, start:stop]
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            # Presumably the row-parallel outputs are summed
                            # across ranks, so only rank 0 keeps the bias.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        start, stop = FlashNeoXSharded._shard_bounds(
                            slice_.get_shape()[0], rank, world_size
                        )
                        tensor = slice_[start:stop]
                    elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
                        start, stop = FlashNeoXSharded._shard_bounds(
                            slice_.get_shape()[0], rank, world_size
                        )
                        tensor = slice_[start:stop]
                    else:
                        try:
                            tensor = slice_[:]
                        except Exception:
                            # NOTE(review): presumably some tensors (e.g. 0-dim)
                            # cannot be sliced; read them whole instead. Narrowed
                            # from a bare `except:` so KeyboardInterrupt and
                            # SystemExit still propagate.
                            tensor = f.get_tensor(name)

                    if (
                        current_parameter_tensor is not None
                        and current_parameter_tensor.shape != tensor.shape
                    ):
                        raise ValueError(
                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous()

                    if current_parameter_tensor is not None:
                        module._parameters[param_name] = tensor
                    else:
                        module._buffers[param_name] = tensor

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_s: int,
        past_key_values: Optional = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the sharded model and return full-vocabulary ``(logits, present)``."""
        if self.model.gpt_neox.tp_embeddings:
            logits, present = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlens=cu_seqlens,
                max_s=max_s,
                past_key_values=past_key_values,
            )

            # Logits are sharded, so we need to gather them
            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
            world_logits = torch.cat(world_logits, dim=1)

            return world_logits, present

        # While the model itself is sharded, the embeddings might not be, as
        # they might not be divisible by num-shard: logits are already complete.
        return super(FlashNeoXSharded, self).forward(
            input_ids, position_ids, cu_seqlens, max_s, past_key_values
        )