Fix flax from_pretrained pytorch weight check (#603)
This commit is contained in:
parent
fb2fbab10b
commit
f810060006
|
@ -307,7 +307,7 @@ class FlaxModelMixin:
|
|||
# Load model
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if from_pt:
|
||||
if not os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
|
||||
if not os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
|
||||
)
|
||||
|
@ -315,8 +315,8 @@ class FlaxModelMixin:
|
|||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
|
||||
# Load from a Flax checkpoint
|
||||
model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
|
||||
# At this stage we don't have a weight file so we will raise an error.
|
||||
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
|
||||
# Check if pytorch weights exist instead
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||
raise EnvironmentError(
|
||||
f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model"
|
||||
" using `from_pt=True`."
|
||||
|
|
Loading…
Reference in New Issue