[ONNX] Stable Diffusion exporter and pipeline (#399)
* initial export and design
* update imports
* custom provider, import fixes
* Update src/diffusers/onnx_utils.py
* Update src/diffusers/onnx_utils.py
* remove push_to_hub
* Update src/diffusers/onnx_utils.py
* remove torch_device
* numpify the rest of the pipeline
* torchify the safety checker
* revert tensor
* Code review suggestions + quality
* fix tests
* fix provider, add an end-to-end test
* style

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent: 7bcc873bb5   commit: 8d9c4a531b
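In short: the new script exports each Stable Diffusion component (text encoder, UNet, VAE encoder/decoder, safety checker) to its own ONNX graph, and the new `StableDiffusionOnnxPipeline` runs them through `onnxruntime`. A minimal sketch of the conversion entry point; the output directory is a placeholder, and the model id is the one exercised by the end-to-end test below:

    from scripts.convert_stable_diffusion_checkpoint_to_onnx import convert_models

    # writes text_encoder/, unet/, vae_encoder/, vae_decoder/ and safety_checker/
    # subfolders, each holding a model.onnx, plus the pipeline config
    convert_models("CompVis/stable-diffusion-v1-4", "./sd_onnx", opset=14)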
scripts/convert_stable_diffusion_checkpoint_to_onnx.py
@@ -0,0 +1,196 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from pathlib import Path

import torch
from torch.onnx import export

from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline
from diffusers.onnx_utils import OnnxRuntimeModel
from packaging import version


is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")


def onnx_export(
    model,
    model_args: tuple,
    output_path: Path,
    ordered_input_names,
    output_names,
    dynamic_axes,
    opset,
    use_external_data_format=False,
):
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
    # so we check the torch version for backwards compatibility
    if is_torch_less_than_1_11:
        export(
            model,
            model_args,
            f=output_path.as_posix(),
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            use_external_data_format=use_external_data_format,
            enable_onnx_checker=True,
            opset_version=opset,
        )
    else:
        export(
            model,
            model_args,
            f=output_path.as_posix(),
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            opset_version=opset,
        )


@torch.no_grad()
def convert_models(model_path: str, output_path: str, opset: int):
    pipeline = StableDiffusionPipeline.from_pretrained(model_path, use_auth_token=True)
    output_path = Path(output_path)

    # TEXT ENCODER
    text_input = pipeline.tokenizer(
        "A sample prompt",
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    onnx_export(
        pipeline.text_encoder,
        # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
        model_args=(text_input.input_ids.to(torch.int32),),
        output_path=output_path / "text_encoder" / "model.onnx",
        ordered_input_names=["input_ids"],
        output_names=["last_hidden_state", "pooler_output"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "sequence"},
        },
        opset=opset,
    )

    # UNET
    onnx_export(
        pipeline.unet,
        model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False),
        output_path=output_path / "unet" / "model.onnx",
        ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
        output_names=["out_sample"],  # has to be different from "sample" for correct tracing
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            "timestep": {0: "batch"},
            "encoder_hidden_states": {0: "batch", 1: "sequence"},
        },
        opset=opset,
        use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
    )

    # VAE ENCODER
    vae_encoder = pipeline.vae
    # need to get the raw tensor output (sample) from the encoder
    vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
    onnx_export(
        vae_encoder,
        model_args=(torch.randn(1, 3, 512, 512), False),
        output_path=output_path / "vae_encoder" / "model.onnx",
        ordered_input_names=["sample", "return_dict"],
        output_names=["latent_sample"],
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
        },
        opset=opset,
    )

    # VAE DECODER
    vae_decoder = pipeline.vae
    # forward only through the decoder part
    vae_decoder.forward = vae_encoder.decode
    onnx_export(
        vae_decoder,
        model_args=(torch.randn(1, 4, 64, 64), False),
        output_path=output_path / "vae_decoder" / "model.onnx",
        ordered_input_names=["latent_sample", "return_dict"],
        output_names=["sample"],
        dynamic_axes={
            "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
        },
        opset=opset,
    )

    # SAFETY CHECKER
    safety_checker = pipeline.safety_checker
    safety_checker.forward = safety_checker.forward_onnx
    onnx_export(
        pipeline.safety_checker,
        model_args=(torch.randn(1, 3, 224, 224), torch.randn(1, 512, 512, 3)),
        output_path=output_path / "safety_checker" / "model.onnx",
        ordered_input_names=["clip_input", "images"],
        output_names=["out_images", "has_nsfw_concepts"],
        dynamic_axes={
            "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            "images": {0: "batch", 1: "channels", 2: "height", 3: "width"},
        },
        opset=opset,
    )

    onnx_pipeline = StableDiffusionOnnxPipeline(
        vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
        text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
        tokenizer=pipeline.tokenizer,
        unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
        scheduler=pipeline.scheduler,
        safety_checker=OnnxRuntimeModel.from_pretrained(output_path / "safety_checker"),
        feature_extractor=pipeline.feature_extractor,
    )

    onnx_pipeline.save_pretrained(output_path)
    print("ONNX pipeline saved to", output_path)

    _ = StableDiffusionOnnxPipeline.from_pretrained(output_path, provider="CPUExecutionProvider")
    print("ONNX pipeline is loadable")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).",
    )

    parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.")

    parser.add_argument(
        "--opset",
        default=14,
        type=int,
        help="The version of the ONNX operator set to use.",
    )

    args = parser.parse_args()

    convert_models(args.model_path, args.output_path, args.opset)
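Once converted, the exported pipeline runs without any PyTorch weights. A minimal sketch, assuming `./sd_onnx` is the `--output_path` used above (the prompt is a placeholder):

    from diffusers import StableDiffusionOnnxPipeline

    pipe = StableDiffusionOnnxPipeline.from_pretrained("./sd_onnx", provider="CPUExecutionProvider")
    image = pipe("a photo of an astronaut riding a horse").images[0]  # default output_type="pil"
    image.save("astronaut.png")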
setup.py
@@ -170,7 +170,7 @@ extras = {}
 extras["quality"] = ["black==22.3", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-builder"]
 extras["docs"] = ["hf-doc-builder"]
 extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"]
-extras["test"] = ["datasets", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"]
+extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"]
 extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"]

 install_requires = [
src/diffusers/__init__.py
@@ -1,10 +1,17 @@
-from .utils import is_inflect_available, is_scipy_available, is_transformers_available, is_unidecode_available
+from .utils import (
+    is_inflect_available,
+    is_onnx_available,
+    is_scipy_available,
+    is_transformers_available,
+    is_unidecode_available,
+)


 __version__ = "0.3.0.dev0"

 from .modeling_utils import ModelMixin
 from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
+from .onnx_utils import OnnxRuntimeModel
 from .optimization import (
     get_constant_schedule,
     get_constant_schedule_with_warmup,
@@ -44,3 +51,9 @@ if is_transformers_available():
     )
 else:
     from .utils.dummy_transformers_objects import *  # noqa F403
+
+
+if is_transformers_available() and is_onnx_available():
+    from .pipelines import StableDiffusionOnnxPipeline
+else:
+    from .utils.dummy_transformers_and_onnx_objects import *  # noqa F403
src/diffusers/onnx_utils.py
@@ -0,0 +1,189 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import shutil
from pathlib import Path
from typing import Optional, Union

import numpy as np

from huggingface_hub import hf_hub_download

from .utils import is_onnx_available, logging


if is_onnx_available():
    import onnxruntime as ort


ONNX_WEIGHTS_NAME = "model.onnx"


logger = logging.get_logger(__name__)


class OnnxRuntimeModel:
    base_model_prefix = "onnx_model"

    def __init__(self, model=None, **kwargs):
        logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
        self.model = model
        self.model_save_dir = kwargs.get("model_save_dir", None)
        self.latest_model_name = kwargs.get("latest_model_name", "model.onnx")

    def __call__(self, **kwargs):
        inputs = {k: np.array(v) for k, v in kwargs.items()}
        return self.model.run(None, inputs)

    @staticmethod
    def load_model(path: Union[str, Path], provider=None):
        """
        Loads an ONNX inference session with a given ExecutionProvider. The default provider is
        `CPUExecutionProvider`.

        Arguments:
            path (`str` or `Path`):
                Directory from which to load the model.
            provider (`str`, *optional*):
                ONNX Runtime execution provider to use for loading the model, defaults to `CPUExecutionProvider`.
        """
        if provider is None:
            logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
            provider = "CPUExecutionProvider"

        return ort.InferenceSession(path, providers=[provider])

    def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the
        latest_model_name.

        Arguments:
            save_directory (`str` or `Path`):
                Directory where to save the model file.
            file_name (`str`, *optional*):
                Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save
                the model with a different name.
        """
        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME

        src_path = self.model_save_dir.joinpath(self.latest_model_name)
        dst_path = Path(save_directory).joinpath(model_file_name)
        if not src_path.samefile(dst_path):
            shutil.copyfile(src_path, dst_path)

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        **kwargs,
    ):
        """
        Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class
        method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        os.makedirs(save_directory, exist_ok=True)

        # saving model weights/files
        self._save_pretrained(save_directory, **kwargs)

    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        use_auth_token: Optional[Union[bool, str, None]] = None,
        revision: Optional[Union[str, None]] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        file_name: Optional[str] = None,
        provider: Optional[str] = None,
        **kwargs,
    ):
        """
        Load a model from a directory or the HF Hub.

        Arguments:
            model_id (`str` or `Path`):
                Directory from which to load.
            use_auth_token (`str` or `bool`):
                Needed to load models from a private or gated repository.
            revision (`str`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id.
            cache_dir (`Union[str, Path]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding
                the cached versions if they exist.
            file_name (`str`):
                Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load
                different model files from the same repository or directory.
            provider (`str`):
                The ONNX Runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`.
            kwargs (`Dict`, *optional*):
                Will be passed to the model during initialization.
        """
        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
        # load model from local directory
        if os.path.isdir(model_id):
            model = OnnxRuntimeModel.load_model(os.path.join(model_id, model_file_name), provider=provider)
            kwargs["model_save_dir"] = Path(model_id)
        # load model from hub
        else:
            # download model
            model_cache_path = hf_hub_download(
                repo_id=model_id,
                filename=model_file_name,
                use_auth_token=use_auth_token,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
            )
            kwargs["model_save_dir"] = Path(model_cache_path).parent
            kwargs["latest_model_name"] = Path(model_cache_path).name
            model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider)
        return cls(model=model, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        model_id: Union[str, Path],
        force_download: bool = True,
        use_auth_token: Optional[str] = None,
        cache_dir: Optional[str] = None,
        **model_kwargs,
    ):
        revision = None
        if len(str(model_id).split("@")) == 2:
            model_id, revision = model_id.split("@")

        return cls._from_pretrained(
            model_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            use_auth_token=use_auth_token,
            **model_kwargs,
        )
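`OnnxRuntimeModel` deliberately mirrors the `save_pretrained`/`from_pretrained` contract of the other pipeline components, which is what lets `DiffusionPipeline` register ONNX sub-models like any other module. A standalone sketch; the local path is a placeholder and the input shapes simply reuse the dummy inputs from the export script:

    import numpy as np
    from diffusers.onnx_utils import OnnxRuntimeModel

    # local directory containing a model.onnx
    unet = OnnxRuntimeModel.from_pretrained("./sd_onnx/unet", provider="CPUExecutionProvider")

    # __call__ converts keyword arguments to numpy arrays and feeds the InferenceSession
    out_sample = unet(
        sample=np.random.randn(2, 4, 64, 64).astype(np.float32),
        timestep=np.array([0, 1], dtype=np.int64),
        encoder_hidden_states=np.random.randn(2, 77, 768).astype(np.float32),
    )[0]

A Hub repository can additionally pin a revision via the `model_id@revision` convention handled in `from_pretrained` above.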
src/diffusers/pipeline_utils.py
@@ -23,6 +23,7 @@ from typing import List, Optional, Union
 import numpy as np
 import torch

+import diffusers
 import PIL
 from huggingface_hub import snapshot_download
 from PIL import Image
@@ -43,6 +44,7 @@ LOADABLE_CLASSES = {
         "ModelMixin": ["save_pretrained", "from_pretrained"],
         "SchedulerMixin": ["save_config", "from_config"],
         "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
+        "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
     },
     "transformers": {
         "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
@@ -278,6 +280,7 @@ class DiffusionPipeline(ConfigMixin):
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
+        provider = kwargs.pop("provider", None)

         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -372,6 +375,8 @@ class DiffusionPipeline(ConfigMixin):
             loading_kwargs = {}
             if issubclass(class_obj, torch.nn.Module):
                 loading_kwargs["torch_dtype"] = torch_dtype
+            if issubclass(class_obj, diffusers.OnnxRuntimeModel):
+                loading_kwargs["provider"] = provider

             # check if the module is in a subdirectory
             if os.path.isdir(os.path.join(cached_folder, name)):
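These hooks let the generic `DiffusionPipeline.from_pretrained` pop a `provider` kwarg and hand it to every `OnnxRuntimeModel` component it instantiates, e.g. (the path is a placeholder):

    from diffusers import StableDiffusionOnnxPipeline

    # `provider` reaches each OnnxRuntimeModel via loading_kwargs above
    pipe = StableDiffusionOnnxPipeline.from_pretrained("./sd_onnx", provider="CUDAExecutionProvider")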
src/diffusers/pipelines/__init__.py
@@ -1,4 +1,4 @@
-from ..utils import is_transformers_available
+from ..utils import is_onnx_available, is_transformers_available
 from .ddim import DDIMPipeline
 from .ddpm import DDPMPipeline
 from .latent_diffusion_uncond import LDMPipeline
@@ -14,3 +14,6 @@ if is_transformers_available():
         StableDiffusionInpaintPipeline,
         StableDiffusionPipeline,
     )
+
+if is_transformers_available() and is_onnx_available():
+    from .stable_diffusion import StableDiffusionOnnxPipeline
src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -1,4 +1,3 @@
-# flake8: noqa
 from dataclasses import dataclass
 from typing import List, Union

@@ -7,7 +6,7 @@ import numpy as np
 import PIL
 from PIL import Image

-from ...utils import BaseOutput, is_transformers_available
+from ...utils import BaseOutput, is_onnx_available, is_transformers_available


 @dataclass
@@ -33,3 +32,6 @@ if is_transformers_available():
     from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
     from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
     from .safety_checker import StableDiffusionSafetyChecker
+
+if is_transformers_available() and is_onnx_available():
+    from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
@@ -0,0 +1,165 @@
import inspect
from typing import List, Optional, Union

import numpy as np

from transformers import CLIPFeatureExtractor, CLIPTokenizer

from ...onnx_utils import OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from . import StableDiffusionPipelineOutput


class StableDiffusionOnnxPipeline(DiffusionPipeline):
    vae_decoder: OnnxRuntimeModel
    text_encoder: OnnxRuntimeModel
    tokenizer: CLIPTokenizer
    unet: OnnxRuntimeModel
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
    safety_checker: OnnxRuntimeModel
    feature_extractor: CLIPFeatureExtractor

    def __init__(
        self,
        vae_decoder: OnnxRuntimeModel,
        text_encoder: OnnxRuntimeModel,
        tokenizer: CLIPTokenizer,
        unet: OnnxRuntimeModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: OnnxRuntimeModel,
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
        scheduler = scheduler.set_format("np")
        self.register_modules(
            vae_decoder=vae_decoder,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        latents: Optional[np.ndarray] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ):
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        text_embeddings = self.text_encoder(input_ids=text_input.input_ids.astype(np.int32))[0]

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
            )
            uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it
        latents_shape = (batch_size, 4, height // 8, width // 8)
        if latents is None:
            latents = np.random.randn(*latents_shape).astype(np.float32)
        elif latents.shape != latents_shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[i]
                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            # predict the noise residual
            noise_pred = self.unet(
                sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
            )
            noise_pred = noise_pred[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
            else:
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae_decoder(latent_sample=latents)[0]

        image = np.clip(image / 2 + 0.5, 0, 1)
        image = image.transpose((0, 2, 3, 1))

        # run safety checker
        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
        image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
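Since this pipeline samples its initial noise with NumPy rather than a `torch.Generator`, determinism comes from seeding NumPy (as the end-to-end test below does) or from passing `latents` explicitly. A sketch with placeholder prompt and path:

    import numpy as np
    from diffusers import StableDiffusionOnnxPipeline

    pipe = StableDiffusionOnnxPipeline.from_pretrained("./sd_onnx", provider="CPUExecutionProvider")

    np.random.seed(0)
    # must match latents_shape = (batch_size, 4, height // 8, width // 8)
    latents = np.random.randn(1, 4, 64, 64).astype(np.float32)
    image = pipe("a painting of a squirrel eating a burger", latents=latents).images[0]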
src/diffusers/pipelines/stable_diffusion/safety_checker.py
@@ -13,7 +13,7 @@ logger = logging.get_logger(__name__)
 def cosine_distance(image_embeds, text_embeds):
     normalized_image_embeds = nn.functional.normalize(image_embeds)
     normalized_text_embeds = nn.functional.normalize(text_embeds)
-    return torch.mm(normalized_image_embeds, normalized_text_embeds.T)
+    return torch.mm(normalized_image_embeds, normalized_text_embeds.t())


 class StableDiffusionSafetyChecker(PreTrainedModel):
@@ -78,3 +78,29 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
             )

         return images, has_nsfw_concepts
+
+    @torch.inference_mode()
+    def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
+        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
+        image_embeds = self.visual_projection(pooled_output)
+
+        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
+        cos_dist = cosine_distance(image_embeds, self.concept_embeds)
+
+        # increase this value to create a stronger `nsfw` filter
+        # at the cost of increasing the possibility of filtering benign images
+        adjustment = 0.0
+
+        special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
+        # special_scores = special_scores.round(decimals=3)
+        special_care = torch.any(special_scores > 0, dim=1)
+        special_adjustment = special_care * 0.01
+        special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
+
+        concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
+        # concept_scores = concept_scores.round(decimals=3)
+        has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
+
+        images[has_nsfw_concepts] = 0.0  # black image
+
+        return images, has_nsfw_concepts
src/diffusers/utils/__init__.py
@@ -25,6 +25,7 @@ from .import_utils import (
     is_flax_available,
     is_inflect_available,
     is_modelcards_available,
+    is_onnx_available,
     is_scipy_available,
     is_tf_available,
     is_torch_available,
src/diffusers/utils/dummy_transformers_and_onnx_objects.py
@@ -0,0 +1,11 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa

from ..utils import DummyObject, requires_backends


class StableDiffusionOnnxPipeline(metaclass=DummyObject):
    _backends = ["transformers", "onnx"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["transformers", "onnx"])
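The dummy object keeps `from diffusers import StableDiffusionOnnxPipeline` working when `onnxruntime` is missing and defers the failure to instantiation. A sketch of the resulting behavior, assuming onnxruntime is not installed:

    from diffusers import StableDiffusionOnnxPipeline  # resolves to the dummy class

    # raises ImportError built from ONNX_IMPORT_ERROR below:
    # "... requires the onnxruntime library but it was not found in your environment ..."
    pipe = StableDiffusionOnnxPipeline()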
src/diffusers/utils/import_utils.py
@@ -136,6 +136,14 @@ except importlib_metadata.PackageNotFoundError:
     _modelcards_available = False


+_onnx_available = importlib.util.find_spec("onnxruntime") is not None
+try:
+    _onnxruntime_version = importlib_metadata.version("onnxruntime")
+    logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")
+except importlib_metadata.PackageNotFoundError:
+    _onnx_available = False
+
+
 _scipy_available = importlib.util.find_spec("scipy") is not None
 try:
     _scipy_version = importlib_metadata.version("scipy")
@@ -172,6 +180,10 @@ def is_modelcards_available():
     return _modelcards_available


+def is_onnx_available():
+    return _onnx_available
+
+
 def is_scipy_available():
     return _scipy_available

@@ -194,6 +206,12 @@ PYTORCH_IMPORT_ERROR = """
 installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
 """

+# docstyle-ignore
+ONNX_IMPORT_ERROR = """
+{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip
+install onnxruntime`
+"""
+
 # docstyle-ignore
 SCIPY_IMPORT_ERROR = """
 {0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
@@ -223,6 +241,7 @@ BACKENDS_MAPPING = OrderedDict(
     [
         ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
         ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
+        ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)),
         ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
         ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
         ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
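A quick availability probe, mirroring how the conditional imports above gate on `is_onnx_available()`:

    from diffusers.utils import is_onnx_available

    if is_onnx_available():
        import onnxruntime as ort
        print(ort.get_available_providers())  # e.g. ["CPUExecutionProvider", ...]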
tests/test_pipelines.py
@@ -40,6 +40,7 @@ from diffusers import (
     ScoreSdeVeScheduler,
     StableDiffusionImg2ImgPipeline,
     StableDiffusionInpaintPipeline,
+    StableDiffusionOnnxPipeline,
     StableDiffusionPipeline,
     UNet2DConditionModel,
     UNet2DModel,
@@ -1277,3 +1278,23 @@ class PipelineTesterMixin(unittest.TestCase):

         assert sampled_array.shape == (512, 768, 3)
         assert np.max(np.abs(sampled_array - expected_array)) < 1e-3
+
+    @slow
+    def test_stable_diffusion_onnx(self):
+        from scripts.convert_stable_diffusion_checkpoint_to_onnx import convert_models
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            convert_models("CompVis/stable-diffusion-v1-4", tmpdirname, opset=14)
+
+            sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(tmpdirname, provider="CUDAExecutionProvider")
+
+        prompt = "A painting of a squirrel eating a burger"
+        np.random.seed(0)
+        output = sd_pipe([prompt], guidance_scale=6.0, num_inference_steps=20, output_type="np")
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 512, 512, 3)
+        expected_slice = np.array([0.0385, 0.0252, 0.0234, 0.0287, 0.0358, 0.0287, 0.0276, 0.0235, 0.0010])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3