make style
This commit is contained in:
parent
5f826a35fb
commit
f38e3626cd
|
@ -13,16 +13,13 @@
|
|||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch.onnx import export
|
||||
|
||||
import onnx
|
||||
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline, AutoencoderKL
|
||||
from packaging import version
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
|
||||
is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")
|
||||
|
@ -79,9 +76,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
|
|||
output_path = Path(output_path)
|
||||
|
||||
# VAE DECODER
|
||||
vae_decoder = AutoencoderKL.from_pretrained(model_path + "/vae")
|
||||
vae_decoder = AutoencoderKL.from_pretrained(model_path + "/vae")
|
||||
vae_latent_channels = vae_decoder.config.latent_channels
|
||||
vae_out_channels = vae_decoder.config.out_channels
|
||||
# forward only through the decoder part
|
||||
vae_decoder.forward = vae_decoder.decode
|
||||
onnx_export(
|
||||
|
|
Loading…
Reference in New Issue