diff --git a/modules/sd_vae.py b/modules/sd_vae.py index b7176125c..e4ff29946 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -85,10 +85,10 @@ def refresh_vae_list(): def find_vae_near_checkpoint(checkpoint_file): - checkpoint_path = os.path.splitext(checkpoint_file)[0] - for vae_location in [f"{checkpoint_path}.vae.pt", f"{checkpoint_path}.vae.ckpt", f"{checkpoint_path}.vae.safetensors"]: - if os.path.isfile(vae_location): - return vae_location + checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0] + for vae_file in vae_dict.values(): + if os.path.basename(vae_file).startswith(checkpoint_path): + return vae_file return None