diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 1f5e95ad..de344d07 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -137,11 +137,19 @@ except importlib_metadata.PackageNotFoundError: _onnx_available = importlib.util.find_spec("onnxruntime") is not None -try: - _onnxruntime_version = importlib_metadata.version("onnxruntime") - logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}") -except importlib_metadata.PackageNotFoundError: - _onnx_available = False +if _onnx_available: + candidates = ("onnxruntime", "onnxruntime-gpu", "onnxruntime-directml", "onnxruntime-openvino") + _onnxruntime_version = None + # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu + for pkg in candidates: + try: + _onnxruntime_version = importlib_metadata.version(pkg) + break + except importlib_metadata.PackageNotFoundError: + pass + _onnx_available = _onnxruntime_version is not None + if _onnx_available: + logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}") _scipy_available = importlib.util.find_spec("scipy") is not None