From 37c9d789aa462f8521902d01b8881425c2c0a62f Mon Sep 17 00:00:00 2001 From: SkyTNT Date: Fri, 16 Sep 2022 18:13:22 +0800 Subject: [PATCH] Fix is_onnx_available (#440) * Fix is_onnx_available Fix: If user install onnxruntime-gpu, is_onnx_available() will return False. * add more onnxruntime candidates * Run `make style` Co-authored-by: anton-l --- src/diffusers/utils/import_utils.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 1f5e95ad..de344d07 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -137,11 +137,19 @@ except importlib_metadata.PackageNotFoundError: _onnx_available = importlib.util.find_spec("onnxruntime") is not None -try: - _onnxruntime_version = importlib_metadata.version("onnxruntime") - logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}") -except importlib_metadata.PackageNotFoundError: - _onnx_available = False +if _onnx_available: + candidates = ("onnxruntime", "onnxruntime-gpu", "onnxruntime-directml", "onnxruntime-openvino") + _onnxruntime_version = None + # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu + for pkg in candidates: + try: + _onnxruntime_version = importlib_metadata.version(pkg) + break + except importlib_metadata.PackageNotFoundError: + pass + _onnx_available = _onnxruntime_version is not None + if _onnx_available: + logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}") _scipy_available = importlib.util.find_spec("scipy") is not None