Updated the flops calculation (checked with fvcore).
This commit is contained in:
parent
9e2a4a52b2
commit
627862c11a
|
@ -205,10 +205,16 @@ impl Config {
|
|||
let intermediate_size = self.intermediate_size? as u64;
|
||||
let num_layers = self.num_layers? as u64;
|
||||
|
||||
let attn_flops = 2 * (num_heads + 2 * num_kv_heads) * head_dim * hidden_size;
|
||||
let o_flops = 2 * num_kv_heads * head_dim * hidden_size;
|
||||
let q_flops = 2 * num_heads * head_dim * hidden_size;
|
||||
let k_flops = 2 * num_kv_heads * head_dim * hidden_size;
|
||||
let v_flops = 2 * num_kv_heads * head_dim * hidden_size;
|
||||
let attn_flops = 2 * num_heads * head_dim * hidden_size;
|
||||
let o_flops = 2 * num_heads * head_dim * hidden_size;
|
||||
let attn_layer_flops = q_flops + k_flops + v_flops + attn_flops + o_flops;
|
||||
|
||||
let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size;
|
||||
let layer_flops = attn_flops + o_flops + gate_up_down_flops;
|
||||
|
||||
let layer_flops = attn_layer_flops + gate_up_down_flops;
|
||||
let total = layer_flops * num_layers;
|
||||
Some(total)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue