[ONNX] Collate the external weights, speed up loading from the hub (#610)
This commit is contained in:
parent
a9fdb3de9e
commit
6bd005ebbe
|
@ -13,11 +13,14 @@
|
|||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.onnx import export
|
||||
|
||||
import onnx
|
||||
from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline
|
||||
from diffusers.onnx_utils import OnnxRuntimeModel
|
||||
from packaging import version
|
||||
|
@ -92,10 +95,11 @@ def convert_models(model_path: str, output_path: str, opset: int):
|
|||
)
|
||||
|
||||
# UNET
|
||||
unet_path = output_path / "unet" / "model.onnx"
|
||||
onnx_export(
|
||||
pipeline.unet,
|
||||
model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False),
|
||||
output_path=output_path / "unet" / "model.onnx",
|
||||
output_path=unet_path,
|
||||
ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
|
||||
output_names=["out_sample"], # has to be different from "sample" for correct tracing
|
||||
dynamic_axes={
|
||||
|
@ -106,6 +110,21 @@ def convert_models(model_path: str, output_path: str, opset: int):
|
|||
opset=opset,
|
||||
use_external_data_format=True, # UNet is > 2GB, so the weights need to be split
|
||||
)
|
||||
unet_model_path = str(unet_path.absolute().as_posix())
|
||||
unet_dir = os.path.dirname(unet_model_path)
|
||||
unet = onnx.load(unet_model_path)
|
||||
# clean up existing tensor files
|
||||
shutil.rmtree(unet_dir)
|
||||
os.mkdir(unet_dir)
|
||||
# collate external tensor files into one
|
||||
onnx.save_model(
|
||||
unet,
|
||||
unet_model_path,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
location="weights.pb",
|
||||
convert_attribute=False,
|
||||
)
|
||||
|
||||
# VAE ENCODER
|
||||
vae_encoder = pipeline.vae
|
||||
|
|
23
setup.py
23
setup.py
|
@ -90,8 +90,10 @@ _deps = [
|
|||
"isort>=5.5.4",
|
||||
"jax>=0.2.8,!=0.3.2,<=0.3.6",
|
||||
"jaxlib>=0.1.65,<=0.3.6",
|
||||
"modelcards==0.1.4",
|
||||
"modelcards>=0.1.4",
|
||||
"numpy",
|
||||
"onnxruntime",
|
||||
"onnxruntime-gpu",
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
"pytest-xdist",
|
||||
|
@ -100,6 +102,7 @@ _deps = [
|
|||
"requests",
|
||||
"tensorboard",
|
||||
"torch>=1.4",
|
||||
"torchvision",
|
||||
"transformers>=4.21.0",
|
||||
]
|
||||
|
||||
|
@ -171,10 +174,20 @@ extras = {}
|
|||
|
||||
|
||||
extras = {}
|
||||
extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-builder"]
|
||||
extras["docs"] = ["hf-doc-builder"]
|
||||
extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"]
|
||||
extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "torchvision", "transformers"]
|
||||
extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
|
||||
extras["docs"] = deps_list("hf-doc-builder")
|
||||
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
|
||||
extras["test"] = deps_list(
|
||||
"datasets",
|
||||
"onnxruntime",
|
||||
"onnxruntime-gpu",
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
"pytest-xdist",
|
||||
"scipy",
|
||||
"torchvision",
|
||||
"transformers"
|
||||
)
|
||||
extras["torch"] = deps_list("torch")
|
||||
|
||||
if os.name == "nt": # windows
|
||||
|
|
|
@ -15,8 +15,10 @@ deps = {
|
|||
"isort": "isort>=5.5.4",
|
||||
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
|
||||
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
|
||||
"modelcards": "modelcards==0.1.4",
|
||||
"modelcards": "modelcards>=0.1.4",
|
||||
"numpy": "numpy",
|
||||
"onnxruntime": "onnxruntime",
|
||||
"onnxruntime-gpu": "onnxruntime-gpu",
|
||||
"pytest": "pytest",
|
||||
"pytest-timeout": "pytest-timeout",
|
||||
"pytest-xdist": "pytest-xdist",
|
||||
|
@ -25,5 +27,6 @@ deps = {
|
|||
"requests": "requests",
|
||||
"tensorboard": "tensorboard",
|
||||
"torch": "torch>=1.4",
|
||||
"torchvision": "torchvision",
|
||||
"transformers": "transformers>=4.21.0",
|
||||
}
|
||||
|
|
|
@ -1373,12 +1373,9 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
|
||||
@slow
|
||||
def test_stable_diffusion_onnx(self):
|
||||
from scripts.convert_stable_diffusion_checkpoint_to_onnx import convert_models
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
convert_models("CompVis/stable-diffusion-v1-4", tmpdirname, opset=14)
|
||||
|
||||
sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(tmpdirname, provider="CUDAExecutionProvider")
|
||||
sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CUDAExecutionProvider", use_auth_token=True
|
||||
)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
np.random.seed(0)
|
||||
|
|
Loading…
Reference in New Issue