fix local subfolder

This commit is contained in:
Patrick von Platen 2022-07-15 17:55:20 +00:00
parent 29628acbec
commit 1c14ce9509
3 changed files with 11 additions and 1 deletions

View File

@ -134,6 +134,10 @@ class ConfigMixin:
if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
# Load from a PyTorch checkpoint
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
):
config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
else:
raise EnvironmentError(
f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."

View File

@ -348,6 +348,10 @@ class ModelMixin(torch.nn.Module):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
):
model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
else:
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."

View File

@ -214,7 +214,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
) ** (0.5)
# full formula (9)
prev_sample = sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
prev_sample = (
sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
)
return prev_sample