diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 137a977b..796e23a4 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -205,10 +205,16 @@ impl Config { let intermediate_size = self.intermediate_size? as u64; let num_layers = self.num_layers? as u64; - let attn_flops = 2 * (num_heads + 2 * num_kv_heads) * head_dim * hidden_size; - let o_flops = 2 * num_kv_heads * head_dim * hidden_size; + let q_flops = 2 * num_heads * head_dim * hidden_size; + let k_flops = 2 * num_kv_heads * head_dim * hidden_size; + let v_flops = 2 * num_kv_heads * head_dim * hidden_size; + let attn_flops = 2 * num_heads * head_dim * hidden_size; + let o_flops = 2 * num_heads * head_dim * hidden_size; + let attn_layer_flops = q_flops + k_flops + v_flops + attn_flops + o_flops; + let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size; - let layer_flops = attn_flops + o_flops + gate_up_down_flops; + + let layer_flops = attn_layer_flops + gate_up_down_flops; let total = layer_flops * num_layers; Some(total) }