"""Tensor-parallel Santacoder model built on flash attention (multi-query)."""

import torch
import torch.distributed

from torch import nn
from transformers.activations import ACT2FN
from typing import Optional

# Flash attention imports
import flash_attn_cuda
from text_generation_server.utils.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelHead,
    TensorParallelEmbedding,
    FastLayerNorm,
    get_linear,
)


def load_multi_mqa(
    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
    """Load the fused query/key-value projection of a multi-query attention layer.

    The query part is sharded across the ranks of ``weights.process_group``
    while the single key/value head (``2 * head_size`` entries) is kept in
    full on every rank.

    Two checkpoint layouts are handled:

    * a fused ``{prefix}.c_attn`` tensor (selected when any key of
      ``weights.routing`` contains ``"c_attn"``), where the last
      ``2 * head_size`` entries along the output dimension are key/value;
    * separate ``{prefix}.q_attn`` and ``{prefix}.kv_attn`` tensors.

    Args:
        config: Model config; ``config.transpose`` tells whether checkpoint
            weights are stored transposed, ``config.quantize`` is forwarded
            to ``get_linear``.
        prefix: Checkpoint name prefix of the attention module.
        weights: Weight loader exposing ``routing``, ``process_group``,
            ``_get_slice``, ``get_sharded``, ``get_tensor``, ``dtype`` and
            ``device``.
        bias: Whether to load a bias alongside the weight.
        head_size: Size of one attention head.
        num_heads: Number of query heads expected on *this* rank (used only
            for the final shape checks).
        hidden_size: Model hidden size (input dimension of the projection).

    Returns:
        A ``TensorParallelColumnLinear`` wrapping the loaded weight (and
        bias, if requested).

    Raises:
        AssertionError: If the query part is not divisible by the world
            size or the loaded tensors do not have the expected shape.
    """
    if any("c_attn" in k for k in weights.routing.keys()):
        slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
        shape = slice_.get_shape()
        world_size = weights.process_group.size()
        rank = weights.process_group.rank()

        if config.transpose:
            # Output features live on dim 1 in the checkpoint; shard the
            # query columns, keep the trailing key/value columns whole.
            block_size = (shape[1] - 2 * head_size) // world_size
            start = rank * block_size
            stop = (rank + 1) * block_size
            assert (shape[1] - 2 * head_size) % world_size == 0
            q_tensor = slice_[:, start:stop]
            kv_tensor = slice_[:, -2 * head_size :]
            weight = torch.cat([q_tensor, kv_tensor], dim=1).T
        else:
            block_size = (shape[0] - 2 * head_size) // world_size
            start = rank * block_size
            stop = (rank + 1) * block_size
            assert (shape[0] - 2 * head_size) % world_size == 0
            q_tensor = slice_[start:stop]
            kv_tensor = slice_[-2 * head_size :]
            weight = torch.cat([q_tensor, kv_tensor], dim=0)

        if bias:
            slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
            shape = slice_.get_shape()
            block_size = (shape[0] - 2 * head_size) // world_size
            assert (shape[0] - 2 * head_size) % world_size == 0
            # Recompute the shard bounds from the bias shape before slicing.
            start = rank * block_size
            stop = (rank + 1) * block_size
            q_tensor = slice_[start:stop]
            kv_tensor = slice_[-2 * head_size :]
            bias_tensor = torch.cat([q_tensor, kv_tensor], dim=0)
        else:
            # Normalise the boolean flag to ``None`` so the shape check and
            # ``get_linear`` below never see ``False``.
            bias_tensor = None
    else:
        if config.transpose:
            w = [
                weights.get_sharded(f"{prefix}.q_attn.weight", dim=1).T,
                weights.get_tensor(f"{prefix}.kv_attn.weight").T,
            ]
        else:
            w = [
                weights.get_sharded(f"{prefix}.q_attn.weight", dim=0),
                weights.get_tensor(f"{prefix}.kv_attn.weight"),
            ]
        # In both layouts the pieces now have output features on dim 0, so
        # query and key/value are stacked along that dimension (stacking
        # along dim 1 could never satisfy the shape check below).
        weight = torch.cat(w, dim=0)

        if bias:
            b = [
                weights.get_sharded(f"{prefix}.q_attn.bias", dim=0),
                weights.get_tensor(f"{prefix}.kv_attn.bias"),
            ]
            bias_tensor = torch.cat(b, dim=0)
        else:
            bias_tensor = None

    expected_out = (num_heads + 2) * head_size

    weight = weight.to(dtype=weights.dtype).to(device=weights.device)
    assert list(weight.shape) == [
        expected_out,
        hidden_size,
    ], f"{weight.shape} != {[expected_out, hidden_size]}"

    if bias_tensor is not None:
        bias_tensor = bias_tensor.to(dtype=weights.dtype).to(device=weights.device)
        assert list(bias_tensor.shape) == [
            expected_out
        ], f"{bias_tensor.shape} != {[expected_out]}"

    return TensorParallelColumnLinear(
        get_linear(weight, bias_tensor, config.quantize)
    )


def load_col(config, prefix: str, weights, bias: bool):
    """Load a column-parallel linear layer from ``{prefix}.weight`` / ``.bias``.

    The weight is sharded along its output dimension (dim 1 of the stored
    tensor when ``config.transpose`` is set, dim 0 otherwise); the bias,
    when requested, is sharded along dim 0.

    Returns:
        A ``TensorParallelColumnLinear`` wrapping the sharded parameters.
    """
    if config.transpose:
        weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
    else:
        weight = weights.get_sharded(f"{prefix}.weight", dim=0)

    bias_tensor = weights.get_sharded(f"{prefix}.bias", dim=0) if bias else None
    return TensorParallelColumnLinear(
        get_linear(weight, bias_tensor, config.quantize)
    )


def load_row(config, prefix: str, weights, bias: bool):
    """Load a row-parallel linear layer from ``{prefix}.weight`` / ``.bias``.

    The weight is sharded along its input dimension (dim 0 of the stored
    tensor when ``config.transpose`` is set, dim 1 otherwise).

    Returns:
        A ``TensorParallelRowLinear`` bound to ``weights.process_group``.
    """
    if config.transpose:
        weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
    else:
        weight = weights.get_sharded(f"{prefix}.weight", dim=1)

    if bias and weights.process_group.rank() == 0:
        # Bias is only loaded on the first rank.
        # NOTE(review): presumably because the row-parallel outputs are summed
        # across ranks, so the bias must be added exactly once — confirm
        # against TensorParallelRowLinear.
        bias_tensor = weights.get_tensor(f"{prefix}.bias")
    else:
        bias_tensor = None

    return TensorParallelRowLinear(
        get_linear(weight, bias_tensor, config.quantize),
        process_group=weights.process_group,
    )


class FlashMQAttention(torch.nn.Module):
    """Multi-query attention layer computed with ``flash_attn_cuda``.

    Query heads are split evenly across the ranks of the process group; a
    single key/value head is shared by all query heads.
    """

    def __init__(self, prefix, config, weights):
        """Build the layer and load its weights under ``prefix``.

        Raises:
            AssertionError: If the number of attention heads is not
                divisible by the process group size.
        """
        super().__init__()
        num_heads = config.num_attention_heads
        hidden_size = config.hidden_size

        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.head_size = hidden_size // num_heads

        # From here on ``self.num_heads`` is the per-rank head count.
        assert self.num_heads % weights.process_group.size() == 0
        self.num_heads = self.num_heads // weights.process_group.size()

        self.softmax_scale = self.head_size ** (-0.5)

        self.c_attn = load_multi_mqa(
            config,
            prefix=prefix,
            weights=weights,
            bias=True,
            head_size=self.head_size,
            hidden_size=hidden_size,
            num_heads=self.num_heads,
        )
        self.c_proj = load_row(
            config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
        )

    def forward(
        self,
        hidden_states,
        cu_seqlens,
        max_s,
        layer_past,
        layer_past_present_indices,
        cu_seqlens_q,
    ):
        """Run attention for a prefill or a decode step.

        Args:
            hidden_states: Input activations; features are on dim 1.
            cu_seqlens: Cumulative sequence lengths of the keys (also used
                for the queries during prefill).
            max_s: Maximum sequence length passed to the kernel.
            layer_past: Key/value cache of this layer; written in place.
            layer_past_present_indices: ``None`` for prefill; otherwise the
                cache positions where the new key/value entries are stored.
            cu_seqlens_q: Cumulative query lengths, used only when decoding.

        Returns:
            The output of the ``c_proj`` projection applied to the
            attention result.
        """
        qkv = self.c_attn(hidden_states)

        # Split query from key_value
        query, key_value = qkv.split(
            [self.head_size * self.num_heads, 2 * self.head_size], dim=1
        )

        # Prepare query and key_value for indexing
        query = query.view(-1, self.num_heads, self.head_size)
        key_value = key_value.view(-1, 2, 1, self.head_size)

        # NOTE(review): the trailing positional arguments of
        # ``flash_attn_cuda.fwd`` below are presumably (dropout_p,
        # softmax_scale, zero_tensors, is_causal, return_softmax, num_splits,
        # generator) — confirm against the installed flash-attn binding.

        # Prefill
        if layer_past_present_indices is None:
            # Copy to layer past
            layer_past[...] = key_value
            # Expand from 1 to num_heads
            key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)

            # output
            attn_output = torch.empty_like(query)
            # flash attention
            flash_attn_cuda.fwd(
                query,
                key_value[:, 0],
                key_value[:, 1],
                attn_output,
                cu_seqlens,
                cu_seqlens,
                max_s,
                max_s,
                0.0,
                self.softmax_scale,
                False,
                True,
                False,
                0,
                None,
            )
        # Decode
        else:
            # Add present to the layer_past tensor at the correct indices
            layer_past[layer_past_present_indices] = key_value
            # Expand from 1 to num_heads
            key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)

            # output
            attn_output = torch.empty_like(query)
            # flash attention
            flash_attn_cuda.fwd(
                query,
                key_value[:, 0],
                key_value[:, 1],
                attn_output,
                cu_seqlens_q,
                cu_seqlens,
                1,
                max_s,
                0.0,
                self.softmax_scale,
                False,
                False,
                False,
                0,
                None,
            )

        return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))


class MLP(nn.Module):
    """Feed-forward block: column-parallel ``c_fc``, activation, row-parallel ``c_proj``."""

    def __init__(self, prefix, config, weights):
        """Select the activation and load both projections under ``prefix``."""
        super().__init__()
        act = config.activation_function
        # Any activation whose name contains "gelu" is routed to torch's
        # functional GELU; the tanh approximation is used only for the
        # "gelu_fast" and "gelu_pytorch_tanh" names.
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate="tanh"
                if act in ["gelu_fast", "gelu_pytorch_tanh"]
                else "none",
            )
        )

        self.c_fc = load_col(
            config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
        )
        self.c_proj = load_row(
            config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
        )

    def forward(self, hidden_states):
        """Apply ``c_fc``, the activation, then ``c_proj``."""
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        return hidden_states


class Block(nn.Module):
    """One transformer layer: layer norm + attention, layer norm + MLP."""

    def __init__(self, layer_id, config, weights):
        """Load layer ``layer_id`` from the ``transformer.h.{layer_id}`` prefix."""
        super().__init__()
        prefix = f"transformer.h.{layer_id}"
        self.ln_1 = FastLayerNorm.load(
            prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
        )
        self.ln_2 = FastLayerNorm.load(
            prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
        )
        self.attn = FlashMQAttention(
            prefix=f"{prefix}.attn",
            config=config,
            weights=weights,
        )
        self.mlp = MLP(
            prefix=f"{prefix}.mlp",
            config=config,
            weights=weights,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cu_seqlens,
        max_s,
        layer_past,
        layer_past_present_indices,
        cu_seqlens_q,
    ):
        """Run the layer and return ``(mlp_output, residual)``.

        The residual is threaded through the two layer norms, which each
        return an updated ``(hidden_states, residual)`` pair; the MLP output
        is returned without being added to the residual here.
        """
        hidden_states, residual = self.ln_1(hidden_states, residual)

        hidden_states = self.attn(
            hidden_states,
            cu_seqlens,
            max_s,
            layer_past,
            layer_past_present_indices,
            cu_seqlens_q,
        )

        hidden_states, residual = self.ln_2(hidden_states, residual)

        mlp_output = self.mlp(hidden_states)

        return mlp_output, residual


class FlashSantacoderModel(nn.Module):
    """Santacoder transformer stack: embeddings, blocks and final layer norm."""

    def __init__(self, config, weights):
        """Load embeddings, ``config.num_hidden_layers`` blocks and ``ln_f``."""
        super().__init__()
        self.config = config

        self.process_group = weights.process_group
        # Both embeddings are created with ``reduce=False``; ``forward`` sums
        # them and issues a single all-reduce.
        self.wte = TensorParallelEmbedding(
            prefix="transformer.wte",
            weights=weights,
            reduce=False,
        )
        self.wpe = TensorParallelEmbedding(
            prefix="transformer.wpe",
            weights=weights,
            reduce=False,
        )

        self.h = nn.ModuleList(
            [
                Block(
                    layer_id,
                    config,
                    weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.ln_f = FastLayerNorm.load(
            prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
        )

        self.head_size = self.h[0].attn.head_size
        self.num_heads = self.h[0].attn.num_heads

    def forward(
        self,
        input_ids,
        position_ids,
        cu_seqlens,
        cu_seqlens_q,
        max_s,
        past_key_values: Optional[torch.Tensor] = None,
        pre_allocate_past_size: Optional[int] = None,
    ):
        """Run the transformer for a prefill or a decode step.

        Args:
            input_ids: Token ids, looked up in ``wte``.
            position_ids: Position ids, looked up in ``wpe``.
            cu_seqlens: Cumulative sequence lengths.
            cu_seqlens_q: Cumulative query lengths (used when decoding).
            max_s: Maximum sequence length.
            past_key_values: Key/value cache; ``None`` triggers a prefill
                and a fresh cache allocation.
            pre_allocate_past_size: During prefill, size of the cache's
                token dimension; defaults to ``len(hidden_states)``.

        Returns:
            ``(hidden_states, past_key_values)`` after the final layer norm.
        """
        hidden_states = self.wte(input_ids) + self.wpe(position_ids)
        torch.distributed.all_reduce(hidden_states, group=self.process_group)

        # Prefill
        if past_key_values is None:
            # Create past tensor
            past_key_values = hidden_states.new_empty(
                (
                    len(self.h),
                    len(hidden_states)
                    if pre_allocate_past_size is None
                    else pre_allocate_past_size,
                    2,
                    1,
                    self.head_size,
                )
            )
            layer_past_present_indices = None
            slice_past_index = len(hidden_states)
        # Decode
        else:
            # Create indices from cumulative sequence lengths
            layer_past_present_indices = cu_seqlens[1:] - 1
            slice_past_index = None

        residual = None
        for i, layer in enumerate(self.h):
            # We added padding that we now need to slice
            layer_past_key_values = (
                past_key_values[i]
                if slice_past_index is None
                else past_key_values[i, :slice_past_index]
            )

            hidden_states, residual = layer(
                hidden_states,
                residual,
                cu_seqlens,
                max_s,
                layer_past_key_values,
                layer_past_present_indices,
                cu_seqlens_q,
            )

        hidden_states, _ = self.ln_f(hidden_states, residual)

        return hidden_states, past_key_values


class FlashSantacoderForCausalLM(nn.Module):
    """Santacoder language model: transformer plus an LM head loaded from ``transformer.wte``."""

    def __init__(self, config, weights):
        """Build the transformer and load the head from the ``transformer.wte`` prefix."""
        super().__init__()
        self.transformer = FlashSantacoderModel(config, weights)
        self.lm_head = TensorParallelHead.load(
            config, prefix="transformer.wte", weights=weights
        )

    def forward(
        self,
        input_ids,
        position_ids,
        cu_seqlens,
        cu_seqlens_q,
        max_s,
        past_key_values: Optional[torch.Tensor] = None,
        pre_allocate_past_size: Optional[int] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
    ):
        """Compute logits and the updated key/value cache.

        When ``lm_head_indices`` is given, only those rows of the hidden
        states are passed to the LM head.

        Returns:
            ``(logits, present)`` where ``present`` is the cache returned by
            the transformer.
        """
        hidden_states, present = self.transformer(
            input_ids,
            position_ids,
            cu_seqlens,
            cu_seqlens_q,
            max_s,
            past_key_values,
            pre_allocate_past_size,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits = self.lm_head(hidden_states)
        return logits, present