Flax `from_pretrained`: clean up `mismatched_keys`. (#630)
Flax `from_pretrained`: clean up `mismatched_keys`. Originally removed in 73e0bc692c5761e55faff39c80a26d7a3cfc748c. The shape-check loop raises a `ValueError` as soon as a checkpoint weight's shape differs from the model's, so `mismatched_keys` is always empty and the branches that consume it are dead code.
This commit is contained in:
parent 84b9df57a7
commit f10576ad5c
@@ -436,9 +436,6 @@ class FlaxModelMixin:
             )
             cls._missing_keys = missing_keys

-        # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
-        # matching the weights in the model.
-        mismatched_keys = []
         for key in state.keys():
             if key in shape_state and state[key].shape != shape_state[key].shape:
                 raise ValueError(
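The check above aborts loading with a `ValueError` on the first mismatch, so the list removed here could never be filled. A minimal standalone sketch of that control flow, with illustrative stand-in data rather than the library code:

# Illustrative sketch: the raise inside the loop means the list never grows.
state = {"kernel": (3, 3)}        # shapes found in the checkpoint (stand-in)
shape_state = {"kernel": (3, 4)}  # shapes expected by the model (stand-in)

mismatched_keys = []
for key in state:
    if key in shape_state and state[key] != shape_state[key]:
        raise ValueError(f"shape mismatch for {key}")  # loading stops here
# Code that inspects `mismatched_keys` afterwards is either unreachable (on a
# mismatch) or sees an empty list, which is why it can be deleted.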
@@ -466,26 +463,13 @@ class FlaxModelMixin:
                 f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                 " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
             )
-        elif len(mismatched_keys) == 0:
+        else:
             logger.info(
                 f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                 f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
                 f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
                 " training."
             )
-        if len(mismatched_keys) > 0:
-            mismatched_warning = "\n".join(
-                [
-                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
-                    for key, shape1, shape2 in mismatched_keys
-                ]
-            )
-            logger.warning(
-                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
-                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
-                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
-                " to use it for predictions and inference."
-            )

         # dictionary of key: dtypes for the model params
         param_dtypes = jax.tree_map(lambda x: x.dtype, state)
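For context, a hedged sketch of how this code path is exercised; the model class exists in diffusers, but the checkpoint id, revision, and subfolder are illustrative assumptions, not taken from this commit:

from diffusers import FlaxUNet2DConditionModel

# Flax from_pretrained returns the model together with its parameter pytree.
# Point the id/revision/subfolder at any repo that ships Flax (.msgpack) weights.
model, params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", revision="flax"
)
# A complete checkpoint hits the new `else:` branch above and logs that all
# weights were initialized; missing keys hit the warning branch instead.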