hf_text-generation-inference/server/text_generation_server/models/flash_mistral.py

532 lines
19 KiB
Python
Raw Normal View History

2023-09-28 01:55:47 -06:00
import math
import torch
import torch.distributed
import numpy as np
from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers.models.llama import LlamaTokenizerFast
2023-12-11 06:43:40 -07:00
from typing import Optional, Tuple, Type, List
2023-09-28 01:55:47 -06:00
from text_generation_server.pb import generate_pb2
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
from text_generation_server.models.cache_manager import (
get_cache_manager,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
MistralConfig,
)
2023-12-11 04:46:30 -07:00
from text_generation_server.utils.speculate import get_speculate
2023-09-28 01:55:47 -06:00
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
HeterogeneousNextTokenChooser,
StoppingCriteria,
)
tracer = trace.get_tracer(__name__)

# Sliding-window settings shared by the batch builder and the model wrapper.
# Both stay None until BaseFlashMistral.__init__ reads them from the model config.
SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None

# Memory pool handle passed to `torch.cuda.graph(...)` when capturing CUDA graphs.
# Only request it when CUDA is usable so that merely importing this module does
# not fail on a machine without a GPU.
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
2023-09-28 01:55:47 -06:00
# Adds windowing logic to FlashCausalLMBatch
@dataclass
class FlashMistralBatch(FlashCausalLMBatch):
    """FlashCausalLMBatch with the extra bookkeeping needed for sliding-window attention."""

    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
    # as we only keep SLIDING_WINDOW values instead of the whole tensor
    prefill_cache_indices: Optional[torch.Tensor] = None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashMistralBatch":
        """Build a prefill batch from a protobuf batch.

        Args:
            pb: Protobuf batch; its ``requests`` are tokenized and flattened.
            tokenizer: Tokenizer used for the inputs and the stopping criteria.
            dtype: Forwarded to ``HeterogeneousNextTokenChooser.from_pb``.
            device: Device on which the model-input tensors are created.

        Returns:
            A batch whose block/slot fields (``block_tables``,
            ``block_tables_tensor``, ``slots``) are left as ``None``.
        """
        batch_inputs = []
        max_truncation = 0
        for r in pb.requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]

        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []
        prefill_cache_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            # The tokenizer truncated to the batch-wide maximum; apply the
            # per-request limit by keeping only the last `r.truncate` tokens.
            tokenized_input = tokenized_input[-r.truncate :]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            speculative_length = get_speculate()
            total_tokens = input_length + max_new_tokens - 1 + speculative_length

            # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            if SLIDING_WINDOW_BLOCKS is not None:
                needed_blocks = min(needed_blocks, SLIDING_WINDOW_BLOCKS)
            blocks += needed_blocks

            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            # Create tensor to slice into the kv tensor in prefill:
            # only the last SLIDING_WINDOW positions of each request are kept.
            if SLIDING_WINDOW is not None:
                request_prefill_cache_indices = torch.arange(
                    cumulative_length + max(0, input_length - SLIDING_WINDOW),
                    cumulative_length + input_length,
                    dtype=torch.int64,
                )
                prefill_cache_indices.append(request_prefill_cache_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                # Every prompt position of this request produces an output row
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                # Only the last prompt position produces an output row
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
            if SLIDING_WINDOW is not None:
                prefill_cache_indices = torch.cat(prefill_cache_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]
            if SLIDING_WINDOW is not None:
                prefill_cache_indices = prefill_cache_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )

        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        prefill_cache_indices = (
            prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None
        )
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            # Convert the concatenated tensor with `.to` rather than wrapping it
            # in `torch.tensor(...)`, which warns when copy-constructing from a tensor.
            prefill_head_indices = torch.cat(prefill_head_indices).to(
                dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            prefill_cache_indices=prefill_cache_indices,
            speculative_ids=None,
        )
2023-12-11 06:43:40 -07:00
class BaseFlashMistral(FlashCausalLM):
    """FlashCausalLM wrapper for Mistral-style models with optional sliding-window attention.

    The concrete config and model classes are injected by subclasses.
    """

    def __init__(
        self,
        config_cls,
        model_cls,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and weights, then build the model.

        Args:
            config_cls: Config class exposing ``from_pretrained``.
            model_cls: Model class called as ``model_cls(config, weights)``.
            model_id: Model identifier forwarded to the loaders.
            revision: Optional model revision.
            quantize: Quantization scheme stored on ``config.quantize``.
            dtype: Torch dtype; defaults to ``torch.float16`` when ``None``.
            trust_remote_code: Forwarded to the tokenizer and config loaders.

        Raises:
            NotImplementedError: If CUDA is not available.
        """
        global SLIDING_WINDOW
        global SLIDING_WINDOW_BLOCKS

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashMistral is only available on GPU")

        tokenizer = LlamaTokenizerFast.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = config_cls.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize

        # Set context windows
        if config.sliding_window is not None:
            SLIDING_WINDOW = config.sliding_window
            SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        model = model_cls(config, weights)

        # Captured CUDA graphs and their static input buffers, keyed by batch size
        self.cuda_graphs = {}

        torch.distributed.barrier(group=self.process_group)
        super(BaseFlashMistral, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashMistralBatch]:
        """Batch class used with this model."""
        return FlashMistralBatch

    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        """Capture a decode CUDA graph for batch size ``bs``.

        Static input buffers are allocated and stored, together with the
        graph and its output logits, in ``self.cuda_graphs[bs]``.

        Args:
            bs: Batch size the static buffers are allocated for.
            max_s: Value used to fill ``input_lengths`` and passed as ``max_s``.
            max_bt: Number of block-table columns per sequence.
        """
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=self.device)
            .repeat(bs)
            .reshape((bs, max_bt))
        )
        kv_cache = get_cache_manager().kv_cache

        self.cuda_graphs[bs] = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths,
        }
        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs]["graph"] = graph

        # The warmup run and the captured run use exactly the same arguments
        forward_kwargs = dict(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            prefill_cache_indices=None,
            lm_head_indices=None,
        )

        torch.cuda.synchronize()
        # Run once outside to warmup
        self.model.forward(**forward_kwargs)
        torch.cuda.synchronize()

        with torch.cuda.graph(graph, pool=MEM_POOL):
            self.cuda_graphs[bs]["logits"] = self.model.forward(**forward_kwargs)
        torch.cuda.synchronize()

    def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the model on ``batch``, replaying a captured CUDA graph when possible.

        The eager path is taken for prefill (``cu_seqlen_prefill`` is set) or
        when no graph exists for the padded batch size; otherwise the inputs
        are copied into the graph's static buffers and the graph is replayed.

        Returns:
            The result of ``self.model.forward`` on the eager path, or the
            graph's static logits sliced to the batch size on the replay path.
            NOTE(review): the annotation declares a tuple; confirm what
            ``self.model.forward`` actually returns.
        """
        # Model Forward
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = get_cache_manager().kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        if batch.speculative_ids is not None:
            # Expand every sequence into `speculative_length + 1` rows: the
            # current token followed by each speculated token.
            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids

        if self.model.max_past is not None:
            max_s = min(self.model.max_past, max_s)

        # Pad the batch size to the key under which a graph would be stored:
        # 1 and 2 are kept, 3 -> 4, 4..8 -> 8, larger sizes round up to a multiple of 8.
        bs = input_ids.shape[0]
        padded_bs = bs
        if bs == 3:
            padded_bs = 4
        elif 3 < bs <= 8:
            padded_bs = 8
        elif bs > 8:
            padded_bs = (bs + 7) // 8 * 8

        # Try to find an associated cuda graph
        cuda_graph = self.cuda_graphs.get(padded_bs, None)

        if cu_seqlen_prefill is not None or cuda_graph is None:
            logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
            )
            # The prefill cache indices are cleared after this forward call
            if batch.prefill_cache_indices is not None:
                batch.prefill_cache_indices = None
            return logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        # Reset the padded rows before writing the real values
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()

        # Slice output to the correct shape
        return cuda_graph["logits"][:bs]
2023-12-11 06:43:40 -07:00
class FlashMistral(BaseFlashMistral):
    """BaseFlashMistral bound to the Mistral config and causal-LM model classes."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Forward all arguments to the base class with the Mistral classes filled in."""
        super().__init__(
            config_cls=MistralConfig,
            model_cls=FlashMistralForCausalLM,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )