From 9fab7c6665783ec2bf4726bb4bf6bb39a6f9a467 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 11 Nov 2024 14:31:32 +0100 Subject: [PATCH] Updated the flops calculation (checked with fvcore). --- launcher/src/main.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 88afb41d..f3f31f66 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -205,10 +205,16 @@ impl Config { let intermediate_size = self.intermediate_size? as u64; let num_layers = self.num_layers? as u64; - let attn_flops = 2 * (num_heads + 2 * num_kv_heads) * head_dim * hidden_size; - let o_flops = 2 * num_kv_heads * head_dim * hidden_size; + let q_flops = 2 * num_heads * head_dim * hidden_size; + let k_flops = 2 * num_kv_heads * head_dim * hidden_size; + let v_flops = 2 * num_kv_heads * head_dim * hidden_size; + let attn_flops = 2 * num_heads * head_dim * hidden_size; + let o_flops = 2 * num_heads * head_dim * hidden_size; + let attn_layer_flops = q_flops + k_flops + v_flops + attn_flops + o_flops; + let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size; - let layer_flops = attn_flops + o_flops + gate_up_down_flops; + + let layer_flops = attn_layer_flops + gate_up_down_flops; let total = layer_flops * num_layers; Some(total) }