Flax `from_pretrained`: clean up `mismatched_keys`. (#630)

Flax from_pretrained: clean up `mismatched_keys`.

Originally removed in 73e0bc692c5761e55faff39c80a26d7a3cfc748c.
This commit is contained in:
Pedro Cuenca 2022-09-29 16:06:19 +02:00 committed by GitHub
parent 84b9df57a7
commit f10576ad5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 17 deletions

View File

@ -436,9 +436,6 @@ class FlaxModelMixin:
) )
cls._missing_keys = missing_keys cls._missing_keys = missing_keys
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys = []
for key in state.keys(): for key in state.keys():
if key in shape_state and state[key].shape != shape_state[key].shape: if key in shape_state and state[key].shape != shape_state[key].shape:
raise ValueError( raise ValueError(
@ -466,26 +463,13 @@ class FlaxModelMixin:
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
" TRAIN this model on a down-stream task to be able to use it for predictions and inference." " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
) )
elif len(mismatched_keys) == 0: else:
logger.info( logger.info(
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
f" was trained on, you can already use {model.__class__.__name__} for predictions without further" f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
" training." " training."
) )
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
for key, shape1, shape2 in mismatched_keys
]
)
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
" to use it for predictions and inference."
)
# dictionary of key: dtypes for the model params # dictionary of key: dtypes for the model params
param_dtypes = jax.tree_map(lambda x: x.dtype, state) param_dtypes = jax.tree_map(lambda x: x.dtype, state)