import math
import torch
import torch.distributed

import numpy as np

from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers.models.llama import LlamaTokenizerFast
from typing import Optional, Tuple, Type, List

from text_generation_server.pb import generate_pb2
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
from text_generation_server.models.cache_manager import (
    get_cache_manager,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
    FlashMistralForCausalLM,
    MistralConfig,
)
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
    HeterogeneousNextTokenChooser,
    StoppingCriteria,
)

tracer = trace.get_tracer(__name__)

# Will be set in init (see BaseFlashMistral.__init__); both stay None when the
# model config has no sliding window.
SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None


# Adds windowing logic to FlashCausalLMBatch
@dataclass
class FlashMistralBatch(FlashCausalLMBatch):
    """FlashCausalLMBatch extended with sliding-window prefill bookkeeping."""

    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
    # as we only keep SLIDING_WINDOW values instead of the whole tensor
    prefill_cache_indices: Optional[torch.Tensor] = None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashMistralBatch":
        """Build a prefill batch from a protobuf batch.

        All request inputs are tokenized in a single call (truncated to the
        largest ``truncate`` of the batch), then each sequence is trimmed to
        its own ``truncate``. Per-request lengths, position ids, slot indices,
        block requirements and prefill-logprob indices are accumulated and
        finally packed into tensors on ``device``.

        When the module-level ``SLIDING_WINDOW`` is set, the number of blocks
        per request is capped at ``SLIDING_WINDOW_BLOCKS`` and
        ``prefill_cache_indices`` selects, for every request, only its last
        ``SLIDING_WINDOW`` prompt positions. Otherwise
        ``prefill_cache_indices`` is ``None``.

        Args:
            pb: Protobuf batch holding the requests.
            tokenizer: Tokenizer applied to the request inputs.
            dtype: Forwarded to ``HeterogeneousNextTokenChooser.from_pb``.
            device: Device on which the batch tensors are created.

        Returns:
            A new batch; ``block_tables``, ``block_tables_tensor``, ``slots``
            and ``speculative_ids`` are left as ``None``.
        """
        # SLIDING_WINDOW / SLIDING_WINDOW_BLOCKS are only read here, so no
        # `global` declaration is needed.
        batch_inputs = []
        max_truncation = 0
        for r in pb.requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]

        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []
        prefill_cache_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            # Keep only the last `truncate` tokens of this request
            tokenized_input = tokenized_input[-r.truncate :]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            speculative_length = get_speculate()
            total_tokens = input_length + max_new_tokens - 1 + speculative_length

            # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            if SLIDING_WINDOW_BLOCKS is not None:
                needed_blocks = min(needed_blocks, SLIDING_WINDOW_BLOCKS)
            blocks += needed_blocks

            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            # Create tensor to slice into the kv tensor in prefill
            if SLIDING_WINDOW is not None:
                request_prefill_cache_indices = torch.arange(
                    cumulative_length + max(0, input_length - SLIDING_WINDOW),
                    cumulative_length + input_length,
                    dtype=torch.int64,
                )
                prefill_cache_indices.append(request_prefill_cache_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                # Every prompt position of this request goes through the head
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                # Only the last prompt position goes through the head
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
            if SLIDING_WINDOW is not None:
                prefill_cache_indices = torch.cat(prefill_cache_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]
            if SLIDING_WINDOW is not None:
                prefill_cache_indices = prefill_cache_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )

        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        prefill_cache_indices = (
            prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None
        )
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            # Convert the concatenated tensor with `.to` rather than wrapping
            # it in `torch.tensor(...)`, which warns when copy-constructing
            # from an existing tensor.
            prefill_head_indices = torch.cat(prefill_head_indices).to(
                dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            prefill_cache_indices=prefill_cache_indices,
            speculative_ids=None,
        )


class BaseFlashMistral(FlashCausalLM):
    """FlashCausalLM whose config/model classes are injected by the subclass."""

    def __init__(
        self,
        config_cls,
        model_cls,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and weights, then build the model.

        Also sets the module-level ``SLIDING_WINDOW`` and
        ``SLIDING_WINDOW_BLOCKS`` when the loaded config defines a sliding
        window.

        Args:
            config_cls: Config class exposing ``from_pretrained``.
            model_cls: Model class called as ``model_cls(config, weights)``.
            model_id: Model identifier passed to the loaders.
            revision: Optional revision passed to the loaders.
            quantize: Quantization scheme stored on the config.
            dtype: Model dtype; defaults to ``torch.float16`` when ``None``.
            trust_remote_code: Forwarded to the ``from_pretrained`` calls.

        Raises:
            NotImplementedError: If CUDA is not available.
        """
        global SLIDING_WINDOW
        global SLIDING_WINDOW_BLOCKS

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            # The message previously named FlashLlama, a copy-paste from
            # another model class.
            raise NotImplementedError("FlashMistral is only available on GPU")

        tokenizer = LlamaTokenizerFast.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = config_cls.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize

        # Set context windows
        if config.sliding_window is not None:
            SLIDING_WINDOW = config.sliding_window
            SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        model = model_cls(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashMistralBatch]:
        """Batch class used with this model."""
        return FlashMistralBatch

    def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the model on ``batch`` and return the model's output.

        When ``batch.speculative_ids`` is set, every sequence is expanded to
        ``speculative_length + 1`` rows (the current token followed by its
        speculative ids) and the positions, slots, input lengths and block
        tables are expanded to match.

        After the call ``batch.prefill_cache_indices`` is reset to ``None``.

        Returns:
            Whatever ``self.model.forward`` returns; the annotation declares a
            pair of tensors, but that cannot be confirmed from this file.
        """
        # Model Forward
        # Inputs shared by the speculative and non-speculative paths
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = get_cache_manager().kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices

        if batch.speculative_ids is not None:
            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            # Interleave each current token with its speculative ids
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids

        logits = self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            prefill_cache_indices=batch.prefill_cache_indices,
            lm_head_indices=lm_head_indices,
        )
        # Reset the indices once they have been passed to the model
        if batch.prefill_cache_indices is not None:
            batch.prefill_cache_indices = None
        return logits


class FlashMistral(BaseFlashMistral):
    """BaseFlashMistral bound to MistralConfig and FlashMistralForCausalLM."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Forward all arguments to BaseFlashMistral with the Mistral classes."""
        super().__init__(
            config_cls=MistralConfig,
            model_cls=FlashMistralForCausalLM,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )