diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index d3c719d..4be8b98 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -39,6 +39,7 @@ from text_generation_server.utils.layers import (
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
     TensorParallelHead,
+    get_linear,
 )
 
 
@@ -59,7 +60,8 @@ class LlamaRMSNorm(nn.Module):
                 hidden_states += residual
             residual = hidden_states
 
-            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states.to(torch.float32)
+            variance = hidden_states.pow(2).mean(-1, keepdim=True)
             hidden_states = hidden_states * torch.rsqrt(
                 variance + self.variance_epsilon
             )
@@ -94,6 +96,27 @@ class LlamaRMSNorm(nn.Module):
             return normed_hidden_states, res
 
 
+def _load_gqa(config, prefix: str, weights):
+    w = [
+        weights.get_sharded(f"{prefix}.q_proj.weight", dim=0),
+        weights.get_sharded(f"{prefix}.k_proj.weight", dim=0),
+        weights.get_sharded(f"{prefix}.v_proj.weight", dim=0),
+    ]
+    weight = torch.cat(w, dim=0)
+    weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    bias = None
+    assert config.hidden_size % config.num_attention_heads == 0
+    head_size = config.hidden_size // config.num_attention_heads
+    assert config.num_attention_heads % weights.process_group.size() == 0
+    num_heads = config.num_attention_heads // weights.process_group.size()
+    num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+    assert list(weight.shape) == [
+        (num_heads + 2 * num_key_value_heads) * head_size,
+        config.hidden_size,
+    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+
+
 class FlashLlamaAttention(torch.nn.Module):
     def __init__(
         self,
@@ -118,22 +141,29 @@ class FlashLlamaAttention(torch.nn.Module):
                 f"and `num_shards`: {weights.process_group.size()}"
             )
         self.num_heads = self.num_heads // weights.process_group.size()
-        self.query_key_value = TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-            dim=0,
-            weights=weights,
-            bias=False,
+        self.num_key_value_heads = (
+            config.num_key_value_heads // weights.process_group.size()
         )
+        if config.num_attention_heads != config.num_key_value_heads:
+            self.query_key_value = _load_gqa(config, prefix, weights)
+        else:
+            self.query_key_value = TensorParallelColumnLinear.load_multi(
+                config,
+                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+                dim=0,
+                weights=weights,
+                bias=False,
+            )
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
         )
+        self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
-            0, self.num_heads, dtype=torch.int32, device=weights.device
-        )
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
@@ -148,26 +178,33 @@ class FlashLlamaAttention(torch.nn.Module):
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
-        qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
+        query, kv = qkv.split(
+            [
+                self.head_size * self.num_heads,
+                2 * self.head_size * self.num_key_value_heads,
+            ],
+            dim=1,
+        )
+        query = query.view(-1, self.num_heads, self.head_size)
+        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
 
-        # Inplace rotary
-        self.rotary_emb(qkv[:, 0], cos, sin)
-        self.rotary_emb(qkv[:, 1], cos, sin)
+        self.rotary_emb(query, cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
         vllm_cache_ops.reshape_and_cache(
-            qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
+            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )
 
         # output tensor
-        attn_output = torch.empty_like(qkv[:, 0])
+        attn_output = torch.empty_like(query)
 
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attention(
-                qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
@@ -179,7 +216,7 @@ class FlashLlamaAttention(torch.nn.Module):
             block_size = kv_cache[1].shape[3]
             vllm_attention_ops.single_query_cached_kv_attention(
                 attn_output,
-                qkv[:, 0],
+                query,
                 kv_cache[0],
                 kv_cache[1],
                 self.kv_head_mapping,
@@ -316,6 +353,7 @@ class FlashLlamaModel(torch.nn.Module):
 
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
+        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
 
     def forward(
         self,
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 417ccab..088b50b 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -69,7 +69,7 @@ class FlashLlama(FlashCausalLM):
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_heads,
+            num_kv_heads=model.model.num_key_value_heads,
             head_size=model.model.head_size,
             dtype=dtype,
             device=device,
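
A note on the new `kv_head_mapping`: under grouped-query attention several query heads share one KV head, and the mapping built in `FlashLlamaAttention.__init__` encodes that grouping for the paged-attention kernel. The sketch below is illustrative only, not part of the patch; the head counts are made-up small numbers.

```python
# Illustrative sketch (not part of the patch); head counts are hypothetical.
import torch

num_heads = 8              # query heads on this shard (assumed for the example)
num_key_value_heads = 2    # KV heads on this shard (assumed for the example)
num_groups = num_heads // num_key_value_heads  # query heads per KV head

# Same construction as in FlashLlamaAttention.__init__: each KV head index is
# repeated once per query head in its group.
kv_head_mapping = torch.arange(
    0, num_key_value_heads, dtype=torch.int32
).repeat_interleave(num_groups)

print(kv_head_mapping)
# tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
# -> query heads 0-3 read KV head 0, query heads 4-7 read KV head 1.
# When num_key_value_heads == num_heads (no GQA) this reduces to arange(num_heads),
# which is what the old code built.
```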
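Similarly, the `qkv.split(...)` in `forward` replaces the old `view(-1, 3, num_heads, head_size)` because query and KV projections no longer have the same head count. The following standalone shape check uses hypothetical GQA sizes (not taken from any shipped config) to show the arithmetic that both the split and the `_load_gqa` weight-shape assert rely on.

```python
# Illustrative sketch (not part of the patch); sizes below are hypothetical.
import torch

num_heads = 64            # query heads (assumed)
num_key_value_heads = 8   # KV heads (assumed)
head_size = 128
num_tokens = 5

# The fused q/k/v projection yields (num_heads + 2 * num_key_value_heads) * head_size
# features per token, which is exactly the weight shape _load_gqa asserts on.
qkv = torch.randn(num_tokens, (num_heads + 2 * num_key_value_heads) * head_size)

query, kv = qkv.split(
    [head_size * num_heads, 2 * head_size * num_key_value_heads], dim=1
)
query = query.view(-1, num_heads, head_size)
kv = kv.view(-1, 2, num_key_value_heads, head_size)

print(query.shape)  # torch.Size([5, 64, 128])
print(kv.shape)     # torch.Size([5, 2, 8, 128]); kv[:, 0] is K, kv[:, 1] is V
```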