make style
This commit is contained in:
parent
5f826a35fb
commit
f38e3626cd
|
@ -13,16 +13,13 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from packaging import version
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
|
|
||||||
import onnx
|
from diffusers import AutoencoderKL
|
||||||
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline, AutoencoderKL
|
|
||||||
from packaging import version
|
|
||||||
|
|
||||||
|
|
||||||
is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")
|
is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")
|
||||||
|
@ -79,9 +76,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
|
||||||
output_path = Path(output_path)
|
output_path = Path(output_path)
|
||||||
|
|
||||||
# VAE DECODER
|
# VAE DECODER
|
||||||
vae_decoder = AutoencoderKL.from_pretrained(model_path + "/vae")
|
vae_decoder = AutoencoderKL.from_pretrained(model_path + "/vae")
|
||||||
vae_latent_channels = vae_decoder.config.latent_channels
|
vae_latent_channels = vae_decoder.config.latent_channels
|
||||||
vae_out_channels = vae_decoder.config.out_channels
|
|
||||||
# forward only through the decoder part
|
# forward only through the decoder part
|
||||||
vae_decoder.forward = vae_decoder.decode
|
vae_decoder.forward = vae_decoder.decode
|
||||||
onnx_export(
|
onnx_export(
|
||||||
|
|
Loading…
Reference in New Issue