[ONNX] Collate the external weights, speed up loading from the hub (#610)

This commit is contained in:
Anton Lozhkov 2022-09-21 22:26:30 +02:00 committed by GitHub
parent a9fdb3de9e
commit 6bd005ebbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 45 additions and 13 deletions

View File

@@ -13,11 +13,14 @@
# limitations under the License.
import argparse
import os
import shutil
from pathlib import Path
import torch
from torch.onnx import export
import onnx
from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline
from diffusers.onnx_utils import OnnxRuntimeModel
from packaging import version
@@ -92,10 +95,11 @@ def convert_models(model_path: str, output_path: str, opset: int):
)
# UNET
unet_path = output_path / "unet" / "model.onnx"
onnx_export(
pipeline.unet,
model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False),
output_path=output_path / "unet" / "model.onnx",
output_path=unet_path,
ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
output_names=["out_sample"], # has to be different from "sample" for correct tracing
dynamic_axes={
@@ -106,6 +110,21 @@ def convert_models(model_path: str, output_path: str, opset: int):
opset=opset,
use_external_data_format=True, # UNet is > 2GB, so the weights need to be split
)
unet_model_path = str(unet_path.absolute().as_posix())
unet_dir = os.path.dirname(unet_model_path)
unet = onnx.load(unet_model_path)
# clean up existing tensor files
shutil.rmtree(unet_dir)
os.mkdir(unet_dir)
# collate external tensor files into one
onnx.save_model(
unet,
unet_model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location="weights.pb",
convert_attribute=False,
)
# VAE ENCODER
vae_encoder = pipeline.vae

View File

@@ -90,8 +90,10 @@ _deps = [
"isort>=5.5.4",
"jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib>=0.1.65,<=0.3.6",
"modelcards==0.1.4",
"modelcards>=0.1.4",
"numpy",
"onnxruntime",
"onnxruntime-gpu",
"pytest",
"pytest-timeout",
"pytest-xdist",
@@ -100,6 +102,7 @@ _deps = [
"requests",
"tensorboard",
"torch>=1.4",
"torchvision",
"transformers>=4.21.0",
]
@@ -171,10 +174,20 @@ extras = {}
extras = {}
extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-builder"]
extras["docs"] = ["hf-doc-builder"]
extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"]
extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "torchvision", "transformers"]
extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
extras["test"] = deps_list(
"datasets",
"onnxruntime",
"onnxruntime-gpu",
"pytest",
"pytest-timeout",
"pytest-xdist",
"scipy",
"torchvision",
"transformers"
)
extras["torch"] = deps_list("torch")
if os.name == "nt": # windows

View File

@@ -15,8 +15,10 @@ deps = {
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
"modelcards": "modelcards==0.1.4",
"modelcards": "modelcards>=0.1.4",
"numpy": "numpy",
"onnxruntime": "onnxruntime",
"onnxruntime-gpu": "onnxruntime-gpu",
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
@@ -25,5 +27,6 @@ deps = {
"requests": "requests",
"tensorboard": "tensorboard",
"torch": "torch>=1.4",
"torchvision": "torchvision",
"transformers": "transformers>=4.21.0",
}

View File

@@ -1373,12 +1373,9 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_stable_diffusion_onnx(self):
from scripts.convert_stable_diffusion_checkpoint_to_onnx import convert_models
with tempfile.TemporaryDirectory() as tmpdirname:
convert_models("CompVis/stable-diffusion-v1-4", tmpdirname, opset=14)
sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(tmpdirname, provider="CUDAExecutionProvider")
sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CUDAExecutionProvider", use_auth_token=True
)
prompt = "A painting of a squirrel eating a burger"
np.random.seed(0)