Updated the FLOPs calculation (verified against fvcore).
This commit is contained in:
parent
9e2a4a52b2
commit
627862c11a
|
@ -205,10 +205,16 @@ impl Config {
|
||||||
let intermediate_size = self.intermediate_size? as u64;
|
let intermediate_size = self.intermediate_size? as u64;
|
||||||
let num_layers = self.num_layers? as u64;
|
let num_layers = self.num_layers? as u64;
|
||||||
|
|
||||||
let attn_flops = 2 * (num_heads + 2 * num_kv_heads) * head_dim * hidden_size;
|
let q_flops = 2 * num_heads * head_dim * hidden_size;
|
||||||
let o_flops = 2 * num_kv_heads * head_dim * hidden_size;
|
let k_flops = 2 * num_kv_heads * head_dim * hidden_size;
|
||||||
|
let v_flops = 2 * num_kv_heads * head_dim * hidden_size;
|
||||||
|
let attn_flops = 2 * num_heads * head_dim * hidden_size;
|
||||||
|
let o_flops = 2 * num_heads * head_dim * hidden_size;
|
||||||
|
let attn_layer_flops = q_flops + k_flops + v_flops + attn_flops + o_flops;
|
||||||
|
|
||||||
let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size;
|
let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size;
|
||||||
let layer_flops = attn_flops + o_flops + gate_up_down_flops;
|
|
||||||
|
let layer_flops = attn_layer_flops + gate_up_down_flops;
|
||||||
let total = layer_flops * num_layers;
|
let total = layer_flops * num_layers;
|
||||||
Some(total)
|
Some(total)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue