From 6bd005ebbe3f4c02ed047d22bd93485d0a63089b Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Wed, 21 Sep 2022 22:26:30 +0200 Subject: [PATCH] [ONNX] Collate the external weights, speed up loading from the hub (#610) --- ...ert_stable_diffusion_checkpoint_to_onnx.py | 21 ++++++++++++++++- setup.py | 23 +++++++++++++++---- src/diffusers/dependency_versions_table.py | 5 +++- tests/test_pipelines.py | 9 +++----- 4 files changed, 45 insertions(+), 13 deletions(-) diff --git a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py index 0e4550b7..beeacfe3 100644 --- a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py +++ b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py @@ -13,11 +13,14 @@ # limitations under the License. import argparse +import os +import shutil from pathlib import Path import torch from torch.onnx import export +import onnx from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline from diffusers.onnx_utils import OnnxRuntimeModel from packaging import version @@ -92,10 +95,11 @@ def convert_models(model_path: str, output_path: str, opset: int): ) # UNET + unet_path = output_path / "unet" / "model.onnx" onnx_export( pipeline.unet, model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False), - output_path=output_path / "unet" / "model.onnx", + output_path=unet_path, ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"], output_names=["out_sample"], # has to be different from "sample" for correct tracing dynamic_axes={ @@ -106,6 +110,21 @@ def convert_models(model_path: str, output_path: str, opset: int): opset=opset, use_external_data_format=True, # UNet is > 2GB, so the weights need to be split ) + unet_model_path = str(unet_path.absolute().as_posix()) + unet_dir = os.path.dirname(unet_model_path) + unet = onnx.load(unet_model_path) + # clean up existing tensor files + shutil.rmtree(unet_dir) + os.mkdir(unet_dir) + # collate external tensor files into one + onnx.save_model( + unet, + unet_model_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location="weights.pb", + convert_attribute=False, + ) # VAE ENCODER vae_encoder = pipeline.vae diff --git a/setup.py b/setup.py index 71e7de77..b3668d2d 100644 --- a/setup.py +++ b/setup.py @@ -90,8 +90,10 @@ _deps = [ "isort>=5.5.4", "jax>=0.2.8,!=0.3.2,<=0.3.6", "jaxlib>=0.1.65,<=0.3.6", - "modelcards==0.1.4", + "modelcards>=0.1.4", "numpy", + "onnxruntime", + "onnxruntime-gpu", "pytest", "pytest-timeout", "pytest-xdist", @@ -100,6 +102,7 @@ _deps = [ "requests", "tensorboard", "torch>=1.4", + "torchvision", "transformers>=4.21.0", ] @@ -171,10 +174,20 @@ extras = {} extras = {} -extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-builder"] -extras["docs"] = ["hf-doc-builder"] -extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"] -extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "torchvision", "transformers"] +extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder") +extras["docs"] = deps_list("hf-doc-builder") +extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards") +extras["test"] = deps_list( + "datasets", + "onnxruntime", + "onnxruntime-gpu", + "pytest", + "pytest-timeout", + "pytest-xdist", + "scipy", + "torchvision", + "transformers" +) extras["torch"] = deps_list("torch") if os.name == "nt": # windows diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index f6fb3973..24564f7b 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -15,8 +15,10 @@ deps = { "isort": "isort>=5.5.4", "jax": "jax>=0.2.8,!=0.3.2,<=0.3.6", "jaxlib": "jaxlib>=0.1.65,<=0.3.6", - "modelcards": "modelcards==0.1.4", + "modelcards": "modelcards>=0.1.4", "numpy": "numpy", + "onnxruntime": "onnxruntime", + "onnxruntime-gpu": "onnxruntime-gpu", "pytest": "pytest", "pytest-timeout": "pytest-timeout", "pytest-xdist": "pytest-xdist", @@ -25,5 +27,6 @@ deps = { "requests": "requests", "tensorboard": "tensorboard", "torch": "torch>=1.4", + "torchvision": "torchvision", "transformers": "transformers>=4.21.0", } diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 145c26e3..71584ecb 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -1373,12 +1373,9 @@ class PipelineTesterMixin(unittest.TestCase): @slow def test_stable_diffusion_onnx(self): - from scripts.convert_stable_diffusion_checkpoint_to_onnx import convert_models - - with tempfile.TemporaryDirectory() as tmpdirname: - convert_models("CompVis/stable-diffusion-v1-4", tmpdirname, opset=14) - - sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(tmpdirname, provider="CUDAExecutionProvider") + sd_pipe = StableDiffusionOnnxPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CUDAExecutionProvider", use_auth_token=True + ) prompt = "A painting of a squirrel eating a burger" np.random.seed(0)