diff --git a/scripts/convert_vae_diff_to_onnx.py b/scripts/convert_vae_diff_to_onnx.py index 4d6053ed..e023e04b 100644 --- a/scripts/convert_vae_diff_to_onnx.py +++ b/scripts/convert_vae_diff_to_onnx.py @@ -13,16 +13,13 @@ # limitations under the License. import argparse -import os -import shutil from pathlib import Path import torch +from packaging import version from torch.onnx import export -import onnx -from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline, AutoencoderKL -from packaging import version +from diffusers import AutoencoderKL is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11") @@ -79,9 +76,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F output_path = Path(output_path) # VAE DECODER - vae_decoder = AutoencoderKL.from_pretrained(model_path + "/vae") + vae_decoder = AutoencoderKL.from_pretrained(model_path + "/vae") vae_latent_channels = vae_decoder.config.latent_channels - vae_out_channels = vae_decoder.config.out_channels # forward only through the decoder part vae_decoder.forward = vae_decoder.decode onnx_export(