Updated the flops calculation (checked with fvcore).

This commit is contained in:
Nicolas Patry 2024-11-11 14:31:32 +01:00
parent 9e2a4a52b2
commit 627862c11a
No known key found for this signature in database
GPG Key ID: D2920555C90F704C
1 changed files with 9 additions and 3 deletions

View File

@ -205,10 +205,16 @@ impl Config {
let intermediate_size = self.intermediate_size? as u64;
let num_layers = self.num_layers? as u64;
let attn_flops = 2 * (num_heads + 2 * num_kv_heads) * head_dim * hidden_size;
let o_flops = 2 * num_kv_heads * head_dim * hidden_size;
let q_flops = 2 * num_heads * head_dim * hidden_size;
let k_flops = 2 * num_kv_heads * head_dim * hidden_size;
let v_flops = 2 * num_kv_heads * head_dim * hidden_size;
let attn_flops = 2 * num_heads * head_dim * hidden_size;
let o_flops = 2 * num_heads * head_dim * hidden_size;
let attn_layer_flops = q_flops + k_flops + v_flops + attn_flops + o_flops;
let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size;
let layer_flops = attn_flops + o_flops + gate_up_down_flops;
let layer_flops = attn_layer_flops + gate_up_down_flops;
let total = layer_flops * num_layers;
Some(total)
}