16 lines
682 B
Python
16 lines
682 B
Python
|
def estimate_model_size(config: dict, *, unit: float = 1e9) -> int:
    """
    Roughly estimate a transformer model's parameter count from its config.

    This is a heuristic meant for comparing models, not an exact count. It
    counts only the embedding matrix (plus an untied output head, if the
    config says so), the attention projections and the MLP matrices of each
    layer; biases, norms and position embeddings are ignored. The MLP is
    assumed to be gated (three matrices: gate, up, down); for models with a
    plain two-matrix MLP the estimate is somewhat high.

    :param config: Model config mapping. Requires ``vocab_size``,
        ``hidden_size``, ``num_hidden_layers`` and ``intermediate_size``.
        Optionally reads ``num_attention_heads`` / ``num_key_value_heads``
        (to shrink K/V for grouped-query attention) and
        ``tie_word_embeddings`` (an explicit ``False`` adds an output head).
    :param unit: Divisor applied to the raw parameter count. The default
        ``1e9`` reports billions; use ``1e6`` for millions or ``1`` for the
        raw count.
    :return: Estimated parameter count divided by ``unit`` and truncated
        toward zero, or 0 if any required key is missing or falsy. With the
        default unit, a model under one billion parameters also yields 0.
    """
    vocab_size = config.get('vocab_size')
    hidden_size = config.get('hidden_size')
    num_hidden_layers = config.get('num_hidden_layers')
    intermediate_size = config.get('intermediate_size')
    if not (vocab_size and hidden_size and num_hidden_layers and intermediate_size):
        return 0

    # Under grouped-query attention, K and V project to fewer heads than Q,
    # so their output width is hidden_size * kv_heads / heads. Without both
    # keys, fall back to ordinary multi-head attention (full width).
    num_heads = config.get('num_attention_heads')
    num_kv_heads = config.get('num_key_value_heads')
    if num_heads and num_kv_heads:
        kv_dim = hidden_size * num_kv_heads // num_heads
    else:
        kv_dim = hidden_size

    # Q and O are hidden x hidden; K and V are hidden x kv_dim.
    attention_params = 2 * hidden_size * hidden_size + 2 * hidden_size * kv_dim
    # Gated MLP: gate and up (hidden -> intermediate) plus down (intermediate -> hidden).
    mlp_params = 3 * hidden_size * intermediate_size

    embedding_params = vocab_size * hidden_size
    # An untied output head is a second vocab x hidden matrix. Only count it
    # when the config explicitly says the embeddings are not tied.
    if config.get('tie_word_embeddings') is False:
        embedding_params *= 2

    total_params = embedding_params + num_hidden_layers * (attention_params + mlp_params)
    return int(total_params / unit)
|