diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index ed62b5fe..e06e7fb7 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -307,7 +307,7 @@ class FlaxModelMixin: # Load model if os.path.isdir(pretrained_model_name_or_path): if from_pt: - if not os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): + if not os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): raise EnvironmentError( f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " ) @@ -315,8 +315,8 @@ class FlaxModelMixin: elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): # Load from a Flax checkpoint model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) - # At this stage we don't have a weight file so we will raise an error. - elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): + # Check if pytorch weights exist instead + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): raise EnvironmentError( f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model" " using `from_pt=True`."