Flax `from_pretrained`: clean up `mismatched_keys`. (#630)
Flax from_pretrained: clean up `mismatched_keys`. Originally removed in 73e0bc692c5761e55faff39c80a26d7a3cfc748c.
This commit is contained in:
parent
84b9df57a7
commit
f10576ad5c
|
@ -436,9 +436,6 @@ class FlaxModelMixin:
|
||||||
)
|
)
|
||||||
cls._missing_keys = missing_keys
|
cls._missing_keys = missing_keys
|
||||||
|
|
||||||
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
|
||||||
# matching the weights in the model.
|
|
||||||
mismatched_keys = []
|
|
||||||
for key in state.keys():
|
for key in state.keys():
|
||||||
if key in shape_state and state[key].shape != shape_state[key].shape:
|
if key in shape_state and state[key].shape != shape_state[key].shape:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -466,26 +463,13 @@ class FlaxModelMixin:
|
||||||
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
||||||
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||||
)
|
)
|
||||||
elif len(mismatched_keys) == 0:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
||||||
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
|
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
|
||||||
f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
|
f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
|
||||||
" training."
|
" training."
|
||||||
)
|
)
|
||||||
if len(mismatched_keys) > 0:
|
|
||||||
mismatched_warning = "\n".join(
|
|
||||||
[
|
|
||||||
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
|
||||||
for key, shape1, shape2 in mismatched_keys
|
|
||||||
]
|
|
||||||
)
|
|
||||||
logger.warning(
|
|
||||||
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
|
||||||
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
|
|
||||||
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
|
|
||||||
" to use it for predictions and inference."
|
|
||||||
)
|
|
||||||
|
|
||||||
# dictionary of key: dtypes for the model params
|
# dictionary of key: dtypes for the model params
|
||||||
param_dtypes = jax.tree_map(lambda x: x.dtype, state)
|
param_dtypes = jax.tree_map(lambda x: x.dtype, state)
|
||||||
|
|
Loading…
Reference in New Issue