[ONNX] Collate the external weights, speed up loading from the hub (#610)

This commit is contained in:
Anton Lozhkov 2022-09-21 22:26:30 +02:00 committed by GitHub
parent a9fdb3de9e
commit 6bd005ebbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 45 additions and 13 deletions

View File

@ -13,11 +13,14 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import os
import shutil
from pathlib import Path from pathlib import Path
import torch import torch
from torch.onnx import export from torch.onnx import export
import onnx
from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline
from diffusers.onnx_utils import OnnxRuntimeModel from diffusers.onnx_utils import OnnxRuntimeModel
from packaging import version from packaging import version
@ -92,10 +95,11 @@ def convert_models(model_path: str, output_path: str, opset: int):
) )
# UNET # UNET
unet_path = output_path / "unet" / "model.onnx"
onnx_export( onnx_export(
pipeline.unet, pipeline.unet,
model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False), model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False),
output_path=output_path / "unet" / "model.onnx", output_path=unet_path,
ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"], ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
output_names=["out_sample"], # has to be different from "sample" for correct tracing output_names=["out_sample"], # has to be different from "sample" for correct tracing
dynamic_axes={ dynamic_axes={
@ -106,6 +110,21 @@ def convert_models(model_path: str, output_path: str, opset: int):
opset=opset, opset=opset,
use_external_data_format=True, # UNet is > 2GB, so the weights need to be split use_external_data_format=True, # UNet is > 2GB, so the weights need to be split
) )
unet_model_path = str(unet_path.absolute().as_posix())
unet_dir = os.path.dirname(unet_model_path)
unet = onnx.load(unet_model_path)
# clean up existing tensor files
shutil.rmtree(unet_dir)
os.mkdir(unet_dir)
# collate external tensor files into one
onnx.save_model(
unet,
unet_model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location="weights.pb",
convert_attribute=False,
)
# VAE ENCODER # VAE ENCODER
vae_encoder = pipeline.vae vae_encoder = pipeline.vae

View File

@ -90,8 +90,10 @@ _deps = [
"isort>=5.5.4", "isort>=5.5.4",
"jax>=0.2.8,!=0.3.2,<=0.3.6", "jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib>=0.1.65,<=0.3.6", "jaxlib>=0.1.65,<=0.3.6",
"modelcards==0.1.4", "modelcards>=0.1.4",
"numpy", "numpy",
"onnxruntime",
"onnxruntime-gpu",
"pytest", "pytest",
"pytest-timeout", "pytest-timeout",
"pytest-xdist", "pytest-xdist",
@ -100,6 +102,7 @@ _deps = [
"requests", "requests",
"tensorboard", "tensorboard",
"torch>=1.4", "torch>=1.4",
"torchvision",
"transformers>=4.21.0", "transformers>=4.21.0",
] ]
@ -171,10 +174,20 @@ extras = {}
extras = {} extras = {}
extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-builder"] extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
extras["docs"] = ["hf-doc-builder"] extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"] extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "torchvision", "transformers"] extras["test"] = deps_list(
"datasets",
"onnxruntime",
"onnxruntime-gpu",
"pytest",
"pytest-timeout",
"pytest-xdist",
"scipy",
"torchvision",
"transformers"
)
extras["torch"] = deps_list("torch") extras["torch"] = deps_list("torch")
if os.name == "nt": # windows if os.name == "nt": # windows

View File

@ -15,8 +15,10 @@ deps = {
"isort": "isort>=5.5.4", "isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6", "jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib": "jaxlib>=0.1.65,<=0.3.6", "jaxlib": "jaxlib>=0.1.65,<=0.3.6",
"modelcards": "modelcards==0.1.4", "modelcards": "modelcards>=0.1.4",
"numpy": "numpy", "numpy": "numpy",
"onnxruntime": "onnxruntime",
"onnxruntime-gpu": "onnxruntime-gpu",
"pytest": "pytest", "pytest": "pytest",
"pytest-timeout": "pytest-timeout", "pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist", "pytest-xdist": "pytest-xdist",
@ -25,5 +27,6 @@ deps = {
"requests": "requests", "requests": "requests",
"tensorboard": "tensorboard", "tensorboard": "tensorboard",
"torch": "torch>=1.4", "torch": "torch>=1.4",
"torchvision": "torchvision",
"transformers": "transformers>=4.21.0", "transformers": "transformers>=4.21.0",
} }

View File

@ -1373,12 +1373,9 @@ class PipelineTesterMixin(unittest.TestCase):
@slow @slow
def test_stable_diffusion_onnx(self): def test_stable_diffusion_onnx(self):
from scripts.convert_stable_diffusion_checkpoint_to_onnx import convert_models sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CUDAExecutionProvider", use_auth_token=True
with tempfile.TemporaryDirectory() as tmpdirname: )
convert_models("CompVis/stable-diffusion-v1-4", tmpdirname, opset=14)
sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(tmpdirname, provider="CUDAExecutionProvider")
prompt = "A painting of a squirrel eating a burger" prompt = "A painting of a squirrel eating a burger"
np.random.seed(0) np.random.seed(0)