From f10576ad5c9dbd17c59e0af12a26583e3a540e20 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 29 Sep 2022 16:06:19 +0200 Subject: [PATCH] Flax `from_pretrained`: clean up `mismatched_keys`. (#630) Flax from_pretrained: clean up `mismatched_keys`. Originally removed in 73e0bc692c5761e55faff39c80a26d7a3cfc748c. --- src/diffusers/modeling_flax_utils.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 7f1d65e2..80c3fb68 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -436,9 +436,6 @@ class FlaxModelMixin: ) cls._missing_keys = missing_keys - # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not - # matching the weights in the model. - mismatched_keys = [] for key in state.keys(): if key in shape_state and state[key].shape != shape_state[key].shape: raise ValueError( @@ -466,26 +463,13 @@ class FlaxModelMixin: f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" " TRAIN this model on a down-stream task to be able to use it for predictions and inference." ) - elif len(mismatched_keys) == 0: + else: logger.info( f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" f" was trained on, you can already use {model.__class__.__name__} for predictions without further" " training." ) - if len(mismatched_keys) > 0: - mismatched_warning = "\n".join( - [ - f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" - for key, shape1, shape2 in mismatched_keys - ] - ) - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" - f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" - f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" - " to use it for predictions and inference." - ) # dictionary of key: dtypes for the model params param_dtypes = jax.tree_map(lambda x: x.dtype, state)