[ONNX] Collate the external weights, speed up loading from the hub (#610)
This commit is contained in:
parent
a9fdb3de9e
commit
6bd005ebbe
|
@ -13,11 +13,14 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
|
|
||||||
|
import onnx
|
||||||
from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline
|
from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline
|
||||||
from diffusers.onnx_utils import OnnxRuntimeModel
|
from diffusers.onnx_utils import OnnxRuntimeModel
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
@ -92,10 +95,11 @@ def convert_models(model_path: str, output_path: str, opset: int):
|
||||||
)
|
)
|
||||||
|
|
||||||
# UNET
|
# UNET
|
||||||
|
unet_path = output_path / "unet" / "model.onnx"
|
||||||
onnx_export(
|
onnx_export(
|
||||||
pipeline.unet,
|
pipeline.unet,
|
||||||
model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False),
|
model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False),
|
||||||
output_path=output_path / "unet" / "model.onnx",
|
output_path=unet_path,
|
||||||
ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
|
ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
|
||||||
output_names=["out_sample"], # has to be different from "sample" for correct tracing
|
output_names=["out_sample"], # has to be different from "sample" for correct tracing
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
|
@ -106,6 +110,21 @@ def convert_models(model_path: str, output_path: str, opset: int):
|
||||||
opset=opset,
|
opset=opset,
|
||||||
use_external_data_format=True, # UNet is > 2GB, so the weights need to be split
|
use_external_data_format=True, # UNet is > 2GB, so the weights need to be split
|
||||||
)
|
)
|
||||||
|
unet_model_path = str(unet_path.absolute().as_posix())
|
||||||
|
unet_dir = os.path.dirname(unet_model_path)
|
||||||
|
unet = onnx.load(unet_model_path)
|
||||||
|
# clean up existing tensor files
|
||||||
|
shutil.rmtree(unet_dir)
|
||||||
|
os.mkdir(unet_dir)
|
||||||
|
# collate external tensor files into one
|
||||||
|
onnx.save_model(
|
||||||
|
unet,
|
||||||
|
unet_model_path,
|
||||||
|
save_as_external_data=True,
|
||||||
|
all_tensors_to_one_file=True,
|
||||||
|
location="weights.pb",
|
||||||
|
convert_attribute=False,
|
||||||
|
)
|
||||||
|
|
||||||
# VAE ENCODER
|
# VAE ENCODER
|
||||||
vae_encoder = pipeline.vae
|
vae_encoder = pipeline.vae
|
||||||
|
|
23
setup.py
23
setup.py
|
@ -90,8 +90,10 @@ _deps = [
|
||||||
"isort>=5.5.4",
|
"isort>=5.5.4",
|
||||||
"jax>=0.2.8,!=0.3.2,<=0.3.6",
|
"jax>=0.2.8,!=0.3.2,<=0.3.6",
|
||||||
"jaxlib>=0.1.65,<=0.3.6",
|
"jaxlib>=0.1.65,<=0.3.6",
|
||||||
"modelcards==0.1.4",
|
"modelcards>=0.1.4",
|
||||||
"numpy",
|
"numpy",
|
||||||
|
"onnxruntime",
|
||||||
|
"onnxruntime-gpu",
|
||||||
"pytest",
|
"pytest",
|
||||||
"pytest-timeout",
|
"pytest-timeout",
|
||||||
"pytest-xdist",
|
"pytest-xdist",
|
||||||
|
@ -100,6 +102,7 @@ _deps = [
|
||||||
"requests",
|
"requests",
|
||||||
"tensorboard",
|
"tensorboard",
|
||||||
"torch>=1.4",
|
"torch>=1.4",
|
||||||
|
"torchvision",
|
||||||
"transformers>=4.21.0",
|
"transformers>=4.21.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -171,10 +174,20 @@ extras = {}
|
||||||
|
|
||||||
|
|
||||||
extras = {}
|
extras = {}
|
||||||
extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-builder"]
|
extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
|
||||||
extras["docs"] = ["hf-doc-builder"]
|
extras["docs"] = deps_list("hf-doc-builder")
|
||||||
extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"]
|
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
|
||||||
extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "torchvision", "transformers"]
|
extras["test"] = deps_list(
|
||||||
|
"datasets",
|
||||||
|
"onnxruntime",
|
||||||
|
"onnxruntime-gpu",
|
||||||
|
"pytest",
|
||||||
|
"pytest-timeout",
|
||||||
|
"pytest-xdist",
|
||||||
|
"scipy",
|
||||||
|
"torchvision",
|
||||||
|
"transformers"
|
||||||
|
)
|
||||||
extras["torch"] = deps_list("torch")
|
extras["torch"] = deps_list("torch")
|
||||||
|
|
||||||
if os.name == "nt": # windows
|
if os.name == "nt": # windows
|
||||||
|
|
|
@ -15,8 +15,10 @@ deps = {
|
||||||
"isort": "isort>=5.5.4",
|
"isort": "isort>=5.5.4",
|
||||||
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
|
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
|
||||||
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
|
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
|
||||||
"modelcards": "modelcards==0.1.4",
|
"modelcards": "modelcards>=0.1.4",
|
||||||
"numpy": "numpy",
|
"numpy": "numpy",
|
||||||
|
"onnxruntime": "onnxruntime",
|
||||||
|
"onnxruntime-gpu": "onnxruntime-gpu",
|
||||||
"pytest": "pytest",
|
"pytest": "pytest",
|
||||||
"pytest-timeout": "pytest-timeout",
|
"pytest-timeout": "pytest-timeout",
|
||||||
"pytest-xdist": "pytest-xdist",
|
"pytest-xdist": "pytest-xdist",
|
||||||
|
@ -25,5 +27,6 @@ deps = {
|
||||||
"requests": "requests",
|
"requests": "requests",
|
||||||
"tensorboard": "tensorboard",
|
"tensorboard": "tensorboard",
|
||||||
"torch": "torch>=1.4",
|
"torch": "torch>=1.4",
|
||||||
|
"torchvision": "torchvision",
|
||||||
"transformers": "transformers>=4.21.0",
|
"transformers": "transformers>=4.21.0",
|
||||||
}
|
}
|
||||||
|
|
|
@ -1373,12 +1373,9 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_stable_diffusion_onnx(self):
|
def test_stable_diffusion_onnx(self):
|
||||||
from scripts.convert_stable_diffusion_checkpoint_to_onnx import convert_models
|
sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(
|
||||||
|
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CUDAExecutionProvider", use_auth_token=True
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
)
|
||||||
convert_models("CompVis/stable-diffusion-v1-4", tmpdirname, opset=14)
|
|
||||||
|
|
||||||
sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(tmpdirname, provider="CUDAExecutionProvider")
|
|
||||||
|
|
||||||
prompt = "A painting of a squirrel eating a burger"
|
prompt = "A painting of a squirrel eating a burger"
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
|
|
Loading…
Reference in New Issue