diff --git a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
new file mode 100644
index 00000000..0e4550b7
--- /dev/null
+++ b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
@@ -0,0 +1,196 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+from pathlib import Path
+
+import torch
+from torch.onnx import export
+
+from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline
+from diffusers.onnx_utils import OnnxRuntimeModel
+from packaging import version
+
+
+is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")
+
+
+def onnx_export(
+    model,
+    model_args: tuple,
+    output_path: Path,
+    ordered_input_names,
+    output_names,
+    dynamic_axes,
+    opset,
+    use_external_data_format=False,
+):
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
+    # so we check the torch version for backwards compatibility
+    if is_torch_less_than_1_11:
+        export(
+            model,
+            model_args,
+            f=output_path.as_posix(),
+            input_names=ordered_input_names,
+            output_names=output_names,
+            dynamic_axes=dynamic_axes,
+            do_constant_folding=True,
+            use_external_data_format=use_external_data_format,
+            enable_onnx_checker=True,
+            opset_version=opset,
+        )
+    else:
+        export(
+            model,
+            model_args,
+            f=output_path.as_posix(),
+            input_names=ordered_input_names,
+            output_names=output_names,
+            dynamic_axes=dynamic_axes,
+            do_constant_folding=True,
+            opset_version=opset,
+        )
+
+
+@torch.no_grad()
+def convert_models(model_path: str, output_path: str, opset: int):
+    pipeline = StableDiffusionPipeline.from_pretrained(model_path, use_auth_token=True)
+    output_path = Path(output_path)
+
+    # TEXT ENCODER
+    text_input = pipeline.tokenizer(
+        "A sample prompt",
+        padding="max_length",
+        max_length=pipeline.tokenizer.model_max_length,
+        truncation=True,
+        return_tensors="pt",
+    )
+    onnx_export(
+        pipeline.text_encoder,
+        # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
+        model_args=(text_input.input_ids.to(torch.int32)),
+        output_path=output_path / "text_encoder" / "model.onnx",
+        ordered_input_names=["input_ids"],
+        output_names=["last_hidden_state", "pooler_output"],
+        dynamic_axes={
+            "input_ids": {0: "batch", 1: "sequence"},
+        },
+        opset=opset,
+    )
+
+    # UNET
+    onnx_export(
+        pipeline.unet,
+        model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False),
+        output_path=output_path / "unet" / "model.onnx",
+        ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
+        output_names=["out_sample"],  # has to be different from "sample" for correct tracing
+        dynamic_axes={
+            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+            "timestep": {0: "batch"},
+            "encoder_hidden_states": {0: "batch", 1: "sequence"},
+        },
+        opset=opset,
+        use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
+    )
+
+    # VAE ENCODER
+    vae_encoder = pipeline.vae
+    # need to get the raw tensor output (sample) from the encoder
+    vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
+    onnx_export(
+        vae_encoder,
+        model_args=(torch.randn(1, 3, 512, 512), False),
+        output_path=output_path / "vae_encoder" / "model.onnx",
+        ordered_input_names=["sample", "return_dict"],
+        output_names=["latent_sample"],
+        dynamic_axes={
+            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+        },
+        opset=opset,
+    )
+
+    # VAE DECODER
+    vae_decoder = pipeline.vae
+    # forward only through the decoder part
+    vae_decoder.forward = vae_encoder.decode
+    onnx_export(
+        vae_decoder,
+        model_args=(torch.randn(1, 4, 64, 64), False),
+        output_path=output_path / "vae_decoder" / "model.onnx",
+        ordered_input_names=["latent_sample", "return_dict"],
+        output_names=["sample"],
+        dynamic_axes={
+            "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+        },
+        opset=opset,
+    )
+
+    # SAFETY CHECKER
+    safety_checker = pipeline.safety_checker
+    safety_checker.forward = safety_checker.forward_onnx
+    onnx_export(
+        pipeline.safety_checker,
+        model_args=(torch.randn(1, 3, 224, 224), torch.randn(1, 512, 512, 3)),
+        output_path=output_path / "safety_checker" / "model.onnx",
+        ordered_input_names=["clip_input", "images"],
+        output_names=["out_images", "has_nsfw_concepts"],
+        dynamic_axes={
+            "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+            "images": {0: "batch", 1: "height", 2: "width", 3: "channels"},  # `images` is NHWC
+        },
+        opset=opset,
+    )
+
+    onnx_pipeline = StableDiffusionOnnxPipeline(
+        vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
+        text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
+        tokenizer=pipeline.tokenizer,
+        unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
+        scheduler=pipeline.scheduler,
+        safety_checker=OnnxRuntimeModel.from_pretrained(output_path / "safety_checker"),
+        feature_extractor=pipeline.feature_extractor,
+    )
+
+    onnx_pipeline.save_pretrained(output_path)
+    print("ONNX pipeline saved to", output_path)
+
+    _ = StableDiffusionOnnxPipeline.from_pretrained(output_path, provider="CPUExecutionProvider")
+    print("ONNX pipeline is loadable")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--model_path",
+        type=str,
+        required=True,
+        help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).",
+    )
+
+    parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.")
+
+    parser.add_argument(
+        "--opset",
+        default=14,
+        type=int,
+        help="The version of the ONNX operator set to use.",
+    )
+
+    args = parser.parse_args()
+
+    convert_models(args.model_path, args.output_path, args.opset)
diff --git a/setup.py b/setup.py
index 7b71bd70..74821fcb 100644
--- a/setup.py
+++ b/setup.py
@@ -170,7 +170,7 @@ extras = {}
 extras["quality"] = ["black==22.3", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-builder"]
 extras["docs"] = ["hf-doc-builder"]
 extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"]
-extras["test"] = ["datasets", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"]
+extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"]
 extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"]
extras["training"] + extras["docs"] install_requires = [ diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 3b37e198..215bfdf3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1,10 +1,17 @@ -from .utils import is_inflect_available, is_scipy_available, is_transformers_available, is_unidecode_available +from .utils import ( + is_inflect_available, + is_onnx_available, + is_scipy_available, + is_transformers_available, + is_unidecode_available, +) __version__ = "0.3.0.dev0" from .modeling_utils import ModelMixin from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel +from .onnx_utils import OnnxRuntimeModel from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, @@ -44,3 +51,9 @@ if is_transformers_available(): ) else: from .utils.dummy_transformers_objects import * # noqa F403 + + +if is_transformers_available() and is_onnx_available(): + from .pipelines import StableDiffusionOnnxPipeline +else: + from .utils.dummy_transformers_and_onnx_objects import * # noqa F403 diff --git a/src/diffusers/onnx_utils.py b/src/diffusers/onnx_utils.py new file mode 100644 index 00000000..e840565d --- /dev/null +++ b/src/diffusers/onnx_utils.py @@ -0,0 +1,189 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import shutil +from pathlib import Path +from typing import Optional, Union + +import numpy as np + +from huggingface_hub import hf_hub_download + +from .utils import is_onnx_available, logging + + +if is_onnx_available(): + import onnxruntime as ort + + +ONNX_WEIGHTS_NAME = "model.onnx" + + +logger = logging.get_logger(__name__) + + +class OnnxRuntimeModel: + base_model_prefix = "onnx_model" + + def __init__(self, model=None, **kwargs): + logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.") + self.model = model + self.model_save_dir = kwargs.get("model_save_dir", None) + self.latest_model_name = kwargs.get("latest_model_name", "model.onnx") + + def __call__(self, **kwargs): + inputs = {k: np.array(v) for k, v in kwargs.items()} + return self.model.run(None, inputs) + + @staticmethod + def load_model(path: Union[str, Path], provider=None): + """ + Loads an ONNX Inference session with an ExecutionProvider. 
Default provider is `CPUExecutionProvider` + + Arguments: + path (`str` or `Path`): + Directory from which to load + provider(`str`, *optional*): + Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider` + """ + if provider is None: + logger.info("No onnxruntime provider specified, using CPUExecutionProvider") + provider = "CPUExecutionProvider" + + return ort.InferenceSession(path, providers=[provider]) + + def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the + latest_model_name. + + Arguments: + save_directory (`str` or `Path`): + Directory where to save the model file. + file_name(`str`, *optional*): + Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the + model with a different name. + """ + model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME + + src_path = self.model_save_dir.joinpath(self.latest_model_name) + dst_path = Path(save_directory).joinpath(model_file_name) + if not src_path.samefile(dst_path): + shutil.copyfile(src_path, dst_path) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + **kwargs, + ): + """ + Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class + method.: + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + # saving model weights/files + self._save_pretrained(save_directory, **kwargs) + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + use_auth_token: Optional[Union[bool, str, None]] = None, + revision: Optional[Union[str, None]] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + file_name: Optional[str] = None, + provider: Optional[str] = None, + **kwargs, + ): + """ + Load a model from a directory or the HF Hub. + + Arguments: + model_id (`str` or `Path`): + Directory from which to load + use_auth_token (`str` or `bool`): + Is needed to load models from a private or gated repository + revision (`str`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id + cache_dir (`Union[str, Path]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + file_name(`str`): + Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load + different model files from the same repository or directory. + provider(`str`): + The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`. 
+            kwargs (`Dict`, *optional*):
+                Will be passed to the model during initialization.
+        """
+        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
+        # load model from local directory
+        if os.path.isdir(model_id):
+            model = OnnxRuntimeModel.load_model(os.path.join(model_id, model_file_name), provider=provider)
+            kwargs["model_save_dir"] = Path(model_id)
+        # load model from hub
+        else:
+            # download model
+            model_cache_path = hf_hub_download(
+                repo_id=model_id,
+                filename=model_file_name,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                cache_dir=cache_dir,
+                force_download=force_download,
+            )
+            kwargs["model_save_dir"] = Path(model_cache_path).parent
+            kwargs["latest_model_name"] = Path(model_cache_path).name
+            model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider)
+        return cls(model=model, **kwargs)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_id: Union[str, Path],
+        force_download: bool = True,
+        use_auth_token: Optional[str] = None,
+        cache_dir: Optional[str] = None,
+        **model_kwargs,
+    ):
+        revision = None
+        if len(str(model_id).split("@")) == 2:
+            model_id, revision = str(model_id).split("@")
+
+        return cls._from_pretrained(
+            model_id=model_id,
+            revision=revision,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            use_auth_token=use_auth_token,
+            **model_kwargs,
+        )
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 1604024b..84ee9e20 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -23,6 +23,7 @@ from typing import List, Optional, Union
 import numpy as np
 import torch
 
+import diffusers
 import PIL
 from huggingface_hub import snapshot_download
 from PIL import Image
@@ -43,6 +44,7 @@ LOADABLE_CLASSES = {
         "ModelMixin": ["save_pretrained", "from_pretrained"],
         "SchedulerMixin": ["save_config", "from_config"],
         "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
+        "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
     },
     "transformers": {
         "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
@@ -278,6 +280,7 @@ class DiffusionPipeline(ConfigMixin):
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
+        provider = kwargs.pop("provider", None)
 
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -372,6 +375,8 @@ class DiffusionPipeline(ConfigMixin):
             loading_kwargs = {}
             if issubclass(class_obj, torch.nn.Module):
                 loading_kwargs["torch_dtype"] = torch_dtype
+            if issubclass(class_obj, diffusers.OnnxRuntimeModel):
+                loading_kwargs["provider"] = provider
 
             # check if the module is in a subdirectory
             if os.path.isdir(os.path.join(cached_folder, name)):
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 40ac3467..3e2aeb4f 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -1,4 +1,4 @@
-from ..utils import is_transformers_available
+from ..utils import is_onnx_available, is_transformers_available
 from .ddim import DDIMPipeline
 from .ddpm import DDPMPipeline
 from .latent_diffusion_uncond import LDMPipeline
@@ -14,3 +14,6 @@ if is_transformers_available():
         StableDiffusionInpaintPipeline,
         StableDiffusionPipeline,
     )
+
+if is_transformers_available() and is_onnx_available():
+    from .stable_diffusion import StableDiffusionOnnxPipeline
diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py
index 8bfa394c..5ffda93f 100644
--- a/src/diffusers/pipelines/stable_diffusion/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -1,4 +1,3 @@
-# flake8: noqa
 from dataclasses import dataclass
 from typing import List, Union
 
@@ -7,7 +6,7 @@ import numpy as np
 import PIL
 from PIL import Image
 
-from ...utils import BaseOutput, is_transformers_available
+from ...utils import BaseOutput, is_onnx_available, is_transformers_available
 
 
 @dataclass
@@ -33,3 +32,6 @@ if is_transformers_available():
     from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
     from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
     from .safety_checker import StableDiffusionSafetyChecker
+
+if is_transformers_available() and is_onnx_available():
+    from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
new file mode 100644
index 00000000..7ff3ff22
--- /dev/null
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
@@ -0,0 +1,165 @@
+import inspect
+from typing import List, Optional, Union
+
+import numpy as np
+
+from transformers import CLIPFeatureExtractor, CLIPTokenizer
+
+from ...onnx_utils import OnnxRuntimeModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from . import StableDiffusionPipelineOutput
+
+
+class StableDiffusionOnnxPipeline(DiffusionPipeline):
+    vae_decoder: OnnxRuntimeModel
+    text_encoder: OnnxRuntimeModel
+    tokenizer: CLIPTokenizer
+    unet: OnnxRuntimeModel
+    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+    safety_checker: OnnxRuntimeModel
+    feature_extractor: CLIPFeatureExtractor
+
+    def __init__(
+        self,
+        vae_decoder: OnnxRuntimeModel,
+        text_encoder: OnnxRuntimeModel,
+        tokenizer: CLIPTokenizer,
+        unet: OnnxRuntimeModel,
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        safety_checker: OnnxRuntimeModel,
+        feature_extractor: CLIPFeatureExtractor,
+    ):
+        super().__init__()
+        scheduler = scheduler.set_format("np")
+        self.register_modules(
+            vae_decoder=vae_decoder,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+        )
+
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 7.5,
+        eta: Optional[float] = 0.0,
+        latents: Optional[np.ndarray] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        **kwargs,
+    ):
+        if isinstance(prompt, str):
+            batch_size = 1
+        elif isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        # get prompt text embeddings
+        text_input = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="np",
+        )
+        text_embeddings = self.text_encoder(input_ids=text_input.input_ids.astype(np.int32))[0]
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance:
+            max_length = text_input.input_ids.shape[-1]
+            uncond_input = self.tokenizer(
+                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+            )
+            uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
+
+        # get the initial random noise unless the user supplied it
+        latents_shape = (batch_size, 4, height // 8, width // 8)
+        if latents is None:
+            latents = np.random.randn(*latents_shape).astype(np.float32)
+        elif latents.shape != latents_shape:
+            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+
+        # set timesteps
+        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+        extra_set_kwargs = {}
+        if accepts_offset:
+            extra_set_kwargs["offset"] = 1
+
+        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+        if isinstance(self.scheduler, LMSDiscreteScheduler):
+            latents = latents * self.scheduler.sigmas[0]
+
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler; it will be ignored for other schedulers.
+        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                sigma = self.scheduler.sigmas[i]
+                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+            # predict the noise residual
+            noise_pred = self.unet(
+                sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
+            )
+            noise_pred = noise_pred[0]
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+            else:
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+        # scale and decode the image latents with the VAE
+        latents = 1 / 0.18215 * latents
+        image = self.vae_decoder(latent_sample=latents)[0]
+
+        image = np.clip(image / 2 + 0.5, 0, 1)
+        image = image.transpose((0, 2, 3, 1))
+
+        # run safety checker
+        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
+        image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
+
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py
index 5a315c33..09de92ee 100644
--- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py
+++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py
@@ -13,7 +13,7 @@ logger = logging.get_logger(__name__)
 def cosine_distance(image_embeds, text_embeds):
     normalized_image_embeds = nn.functional.normalize(image_embeds)
     normalized_text_embeds = nn.functional.normalize(text_embeds)
-    return torch.mm(normalized_image_embeds, normalized_text_embeds.T)
+    return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
 
 
 class StableDiffusionSafetyChecker(PreTrainedModel):
@@ -78,3 +78,29 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
             )
 
         return images, has_nsfw_concepts
+
+    @torch.inference_mode()
+    def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
+        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
+        image_embeds = self.visual_projection(pooled_output)
+
+        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
+        cos_dist = cosine_distance(image_embeds, self.concept_embeds)
+
+        # increase this value to create a stronger `nsfw` filter
+        # at the cost of increasing the possibility of filtering benign images
+        adjustment = 0.0
+
+        special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
+        # special_scores = special_scores.round(decimals=3)
+        special_care = torch.any(special_scores > 0, dim=1)
+        special_adjustment = special_care * 0.01
+        special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
+
+        concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
+        # concept_scores = concept_scores.round(decimals=3)
+        has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
+
+        images[has_nsfw_concepts] = 0.0  # black image
+
+        return images, has_nsfw_concepts
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index f9172e8d..c00a28e1 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -25,6 +25,7 @@ from .import_utils import (
     is_flax_available,
     is_inflect_available,
     is_modelcards_available,
+    is_onnx_available,
     is_scipy_available,
     is_tf_available,
     is_torch_available,
diff --git a/src/diffusers/utils/dummy_transformers_and_onnx_objects.py b/src/diffusers/utils/dummy_transformers_and_onnx_objects.py
new file mode 100644
index 00000000..2e34b5ce
--- /dev/null
+++ b/src/diffusers/utils/dummy_transformers_and_onnx_objects.py
@@ -0,0 +1,11 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+
+from ..utils import DummyObject, requires_backends
+
+
+class StableDiffusionOnnxPipeline(metaclass=DummyObject):
+    _backends = ["transformers", "onnx"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["transformers", "onnx"])
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index 05068b6d..1f5e95ad 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -136,6 +136,14 @@ except importlib_metadata.PackageNotFoundError:
     _modelcards_available = False
 
 
+_onnx_available = importlib.util.find_spec("onnxruntime") is not None
+try:
+    _onnxruntime_version = importlib_metadata.version("onnxruntime")
+    logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")
+except importlib_metadata.PackageNotFoundError:
+    _onnx_available = False
+
+
 _scipy_available = importlib.util.find_spec("scipy") is not None
 try:
     _scipy_version = importlib_metadata.version("scipy")
@@ -172,6 +180,10 @@ def is_modelcards_available():
     return _modelcards_available
 
 
+def is_onnx_available():
+    return _onnx_available
+
+
 def is_scipy_available():
     return _scipy_available
 
@@ -194,6 +206,12 @@ PYTORCH_IMPORT_ERROR = """
 installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
 """
 
+# docstyle-ignore
+ONNX_IMPORT_ERROR = """
+{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip
+install onnxruntime`
+"""
+
 # docstyle-ignore
 SCIPY_IMPORT_ERROR = """
 {0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
@@ -223,6 +241,7 @@ BACKENDS_MAPPING = OrderedDict(
     [
         ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
         ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
+        ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)),
         ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
         ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
         ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 859a1e1c..280bffc2 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -40,6 +40,7 @@ from diffusers import (
     ScoreSdeVeScheduler,
     StableDiffusionImg2ImgPipeline,
     StableDiffusionInpaintPipeline,
+    StableDiffusionOnnxPipeline,
     StableDiffusionPipeline,
     UNet2DConditionModel,
     UNet2DModel,
@@ -1277,3 +1278,23 @@ class PipelineTesterMixin(unittest.TestCase):
 
         assert sampled_array.shape == (512, 768, 3)
         assert np.max(np.abs(sampled_array - expected_array)) < 1e-3
+
+    @slow
+    def test_stable_diffusion_onnx(self):
+        from scripts.convert_stable_diffusion_checkpoint_to_onnx import convert_models
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            convert_models("CompVis/stable-diffusion-v1-4", tmpdirname, opset=14)
+
+            sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(tmpdirname, provider="CUDAExecutionProvider")
+
+            prompt = "A painting of a squirrel eating a burger"
+            np.random.seed(0)
+            output = sd_pipe([prompt], guidance_scale=6.0, num_inference_steps=20, output_type="np")
+            image = output.images
+
+            image_slice = image[0, -3:, -3:, -1]
+
+            assert image.shape == (1, 512, 512, 3)
+            expected_slice = np.array([0.0385, 0.0252, 0.0234, 0.0287, 0.0358, 0.0287, 0.0276, 0.0235, 0.0010])
+            assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
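
For reference, here is a minimal end-to-end usage sketch of the conversion script and the new pipeline added by this patch. The model id, output directory, and prompt below are only examples (converting `CompVis/stable-diffusion-v1-4` requires accepting its license and authenticating via `huggingface-cli login`):

    # Step 1: convert the PyTorch checkpoint to ONNX (run from the repository root):
    #   python scripts/convert_stable_diffusion_checkpoint_to_onnx.py \
    #       --model_path="CompVis/stable-diffusion-v1-4" \
    #       --output_path="./stable_diffusion_onnx" \
    #       --opset=14

    # Step 2: load the converted pipeline with a specific onnxruntime execution provider
    # and run inference; the output images are PIL images by default.
    from diffusers import StableDiffusionOnnxPipeline

    pipe = StableDiffusionOnnxPipeline.from_pretrained("./stable_diffusion_onnx", provider="CPUExecutionProvider")
    image = pipe("a photo of an astronaut riding a horse").images[0]
    image.save("astronaut.png")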