Fix flax from_pretrained pytorch weight check (#603)

This commit is contained in:
Mishig Davaadorj 2022-09-21 11:17:15 +02:00 committed by GitHub
parent fb2fbab10b
commit f810060006
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 3 deletions

View File

@ -307,7 +307,7 @@ class FlaxModelMixin:
# Load model
if os.path.isdir(pretrained_model_name_or_path):
if from_pt:
if not os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
if not os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
)
@ -315,8 +315,8 @@ class FlaxModelMixin:
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint
model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
# At this stage we don't have a weight file so we will raise an error.
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
# Check if pytorch weights exist instead
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
raise EnvironmentError(
f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model"
" using `from_pt=True`."