14 lines
566 B
Python
14 lines
566 B
Python
|
def estimate_model_size(
    config: dict,
    *,
    unit: float = 1e9,
    gated_mlp: bool = True,
) -> int:
    """
    Estimate the parameter count of a transformer model from its config.

    This is a rough estimate intended for comparing models, not an exact count.
    It counts:

    - the token embedding matrix;
    - the attention projections (Q, K, V, O) of every layer;
    - the MLP projections of every layer;
    - a separate output head when the config explicitly marks embeddings as untied.

    Biases, normalisation weights and positional embeddings are ignored.

    :param config: Model config mapping. Key names follow the Hugging Face
        ``config.json`` naming. Required keys: ``vocab_size``, ``hidden_size``,
        ``num_hidden_layers``, ``intermediate_size``. Optional keys:
        ``num_attention_heads`` together with ``num_key_value_heads``
        (grouped-query attention), and ``tie_word_embeddings``.
    :param unit: Divisor applied to the raw parameter count before rounding.
        The default ``1e9`` reports billions. Pass ``1e6`` for millions or
        ``1`` for the raw count, e.g. to tell apart sub-billion models.
    :param gated_mlp: ``True`` for a gated MLP with three matrices (gate, up and
        down projections, e.g. SwiGLU). ``False`` for a classic two-matrix MLP.
    :return: Estimated parameter count in multiples of ``unit``, rounded to
        the nearest integer.
    :raises KeyError: If a required config key is missing.
    :raises ValueError: If ``unit`` is not positive.
    """
    if unit <= 0:
        raise ValueError(f"unit must be positive, got {unit!r}")

    vocab_size = config['vocab_size']
    hidden_size = config['hidden_size']
    num_hidden_layers = config['num_hidden_layers']
    intermediate_size = config['intermediate_size']

    # With grouped-query attention, K and V project hidden_size down to
    # num_key_value_heads * head_dim rather than to hidden_size.
    # A missing or None num_key_value_heads means plain multi-head attention.
    num_heads = config.get('num_attention_heads')
    num_kv_heads = config.get('num_key_value_heads') or num_heads
    if num_heads and num_kv_heads:
        kv_dim = hidden_size * num_kv_heads // num_heads
    else:
        kv_dim = hidden_size

    # Q and O are hidden x hidden; K and V are hidden x kv_dim.
    attention_params = 2 * hidden_size * hidden_size + 2 * hidden_size * kv_dim
    mlp_matrices = 3 if gated_mlp else 2
    mlp_params = mlp_matrices * hidden_size * intermediate_size

    embedding_params = vocab_size * hidden_size
    # An untied output head is a second vocab x hidden matrix. Only count it
    # when the config says so explicitly, because the default differs between
    # model families.
    if config.get('tie_word_embeddings') is False:
        embedding_params *= 2

    total_params = embedding_params + num_hidden_layers * (attention_params + mlp_params)
    # round() instead of int(): truncation would under-report, e.g. 6.74B -> 6.
    return round(total_params / unit)
|