make style

This commit is contained in:
Patrick von Platen 2023-03-06 10:40:18 +00:00
parent 5f826a35fb
commit f38e3626cd
1 changed files with 3 additions and 7 deletions

View File

@ -13,16 +13,13 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import os
import shutil
from pathlib import Path from pathlib import Path
import torch import torch
from packaging import version
from torch.onnx import export from torch.onnx import export
import onnx from diffusers import AutoencoderKL
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline, AutoencoderKL
from packaging import version
is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11") is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")
@ -79,9 +76,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
output_path = Path(output_path) output_path = Path(output_path)
# VAE DECODER # VAE DECODER
vae_decoder = AutoencoderKL.from_pretrained(model_path + "/vae") vae_decoder = AutoencoderKL.from_pretrained(model_path + "/vae")
vae_latent_channels = vae_decoder.config.latent_channels vae_latent_channels = vae_decoder.config.latent_channels
vae_out_channels = vae_decoder.config.out_channels
# forward only through the decoder part # forward only through the decoder part
vae_decoder.forward = vae_decoder.decode vae_decoder.forward = vae_decoder.decode
onnx_export( onnx_export(