import math

import torch
import torch.distributed

import numpy as np

from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers.models.llama import LlamaTokenizerFast
from typing import Optional, Tuple, Type

from text_generation_server.pb import generate_pb2
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
from text_generation_server.models.cache_manager import (
    get_cache_manager,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
    FlashMistralForCausalLM,
    MistralConfig,
)
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
    HeterogeneousNextTokenChooser,
    StoppingCriteria,
)

tracer = trace.get_tracer(__name__)

# Will be set in init
SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None

MEM_POOL = torch.cuda.graph_pool_handle()
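# Shared memory-pool handle: passed as `pool=MEM_POOL` to every capture in
# `cuda_graph_warmup` below, so the graphs captured for different batch sizes
# share a single allocation pool.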


def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
    global SLIDING_WINDOW
    global SLIDING_WINDOW_BLOCKS
    SLIDING_WINDOW = sliding_window
    SLIDING_WINDOW_BLOCKS = sliding_window_blocks


def get_sliding_windows() -> Tuple[int, int]:
    global SLIDING_WINDOW
    global SLIDING_WINDOW_BLOCKS
    return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
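
# Note: both globals stay None unless `set_sliding_window` is called (done in
# `BaseFlashMistral.__init__` only when the model config defines a sliding window),
# so callers such as `FlashMistralBatch.from_pb` must handle the None case.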


# Adds windowing logic to FlashCausalLMBatch
@dataclass
class FlashMistralBatch(FlashCausalLMBatch):
    # Prefill cache indices are used to slice into the kv tensor before caching it into
    # the paged attention buffers, as we only keep SLIDING_WINDOW values instead of the whole tensor
    prefill_cache_indices: Optional[torch.Tensor] = None
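    # For example (hypothetical numbers): with SLIDING_WINDOW = 4096 and a 6000-token
    # prompt, `from_pb` builds indices covering only the last 4096 positions, so just
    # that window of key/values is written to the paged KV cache during prefill.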

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        sliding_window, sliding_window_blocks = get_sliding_windows()

        batch_inputs = []
        max_truncation = 0
        for r in pb.requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]

        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []
        prefill_cache_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)
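            # The two offsets above drive incremental detokenization; the 5-token
            # lookback appears to follow the FlashCausalLMBatch convention so the
            # streaming decoder has context to re-merge trailing tokens cleanly.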

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            speculative_length = get_speculate()
            total_tokens = input_length + max_new_tokens - 1 + speculative_length
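            # e.g. (hypothetical) input_length=10, max_new_tokens=20, speculate=2
            # -> 10 + 20 - 1 + 2 = 31 KV-cache slots reserved for this request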

            # Needed blocks cannot go over SLIDING_WINDOW_BLOCKS
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            if sliding_window_blocks is not None:
                needed_blocks = min(needed_blocks, sliding_window_blocks)
            blocks += needed_blocks
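            # With a sliding window the KV cache is overwritten as a circular buffer
            # (see `forward` below), so a request never needs more than
            # SLIDING_WINDOW_BLOCKS blocks no matter how long the generation runs.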

            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            # Create tensor to slice into the kv tensor in prefill
            if sliding_window is not None:
                request_prefill_cache_indices = torch.arange(
                    cumulative_length + max(0, input_length - sliding_window),
                    cumulative_length + input_length,
                    dtype=torch.int64,
                )
                prefill_cache_indices.append(request_prefill_cache_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device, tokenizer
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
            if sliding_window is not None:
                prefill_cache_indices = torch.cat(prefill_cache_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]
            if sliding_window is not None:
                prefill_cache_indices = prefill_cache_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )

        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        prefill_cache_indices = (
            prefill_cache_indices.to(device) if sliding_window is not None else None
        )
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.tensor(
                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            prefill_cache_indices=prefill_cache_indices,
            speculative_ids=None,
        )


class BaseFlashMistral(FlashCausalLM):
    def __init__(
        self,
        config_cls,
        model_cls,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashMistral is only available on GPU")

        tokenizer = LlamaTokenizerFast.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = config_cls.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.use_medusa = use_medusa

        # Set context windows
        if config.sliding_window is not None:
            set_sliding_window(
                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
            )
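            # math.ceil(sliding_window / BLOCK_SIZE) is the number of paged-attention
            # blocks the window spans, e.g. ceil(4096 / 16) = 256 assuming BLOCK_SIZE
            # is 16 (illustrative numbers only).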

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        model = model_cls(config, weights)

        self.cuda_graphs = {}
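        # Populated lazily by `cuda_graph_warmup`, keyed by (padded) decode batch size.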

        torch.distributed.barrier(group=self.process_group)
        super(BaseFlashMistral, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashMistralBatch]:
        return FlashMistralBatch

    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
        block_tables = (
            torch.arange(max_bt, dtype=torch.int32, device=self.device)
            .repeat(bs)
            .reshape((bs, max_bt))
        )
        kv_cache = get_cache_manager().kv_cache

        self.cuda_graphs[bs] = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "kv_cache": kv_cache,
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths,
        }
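        # These tensors are the static buffers the captured graph reads on every
        # replay; `forward` copies (padded) live batch data into them before replaying.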
        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs]["graph"] = graph

        torch.cuda.synchronize()
        # Run once outside to warmup
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            prefill_cache_indices=None,
            lm_head_indices=None,
        )
        torch.cuda.synchronize()

        with torch.cuda.graph(graph, pool=MEM_POOL):
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=None,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=None,
                lm_head_indices=None,
            )
            self.cuda_graphs[bs]["logits"] = logits
            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
        torch.cuda.synchronize()
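        # The capture above records one decode step over the static buffers; at
        # inference time `forward` replays it with `cuda_graph["graph"].replay()` and
        # slices the cached logits / speculative_logits down to the real batch size.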

    def forward(
        self, batch: FlashMistralBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Model Forward
        if batch.speculative_ids is not None:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
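            # Each request now contributes `speculative_length + 1` ids to this decode
            # step (the current token plus its speculative candidates), e.g. 3 ids per
            # request when speculating 2 tokens ahead.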
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

        if cu_seqlen_prefill is None and self.model.max_past is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.model.max_past, max_s)

        bs = input_ids.shape[0]
        padded_bs = bs
        if bs == 3:
            padded_bs = 4
        elif 3 < bs <= 8:
            padded_bs = 8
        elif bs > 8:
            padded_bs = (bs + 7) // 8 * 8
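        # e.g. bs=3 -> 4, bs=5 -> 8, bs=17 -> 24; if no graph was captured for the
        # padded size, the lookup below returns None and we fall back to eager forward.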

        # Try to find an associated cuda graph
        cuda_graph = self.cuda_graphs.get(padded_bs, None)

        if cu_seqlen_prefill is not None or cuda_graph is None:
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
            )
            if batch.prefill_cache_indices is not None:
                batch.prefill_cache_indices = None
            return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits


class FlashMistral(BaseFlashMistral):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        super(FlashMistral, self).__init__(
            config_cls=MistralConfig,
            model_cls=FlashMistralForCausalLM,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )