From 7e29b747f935c08f5b85c6eccaad4490023788c9 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 17 Jan 2023 21:16:46 +0100 Subject: [PATCH] Check k-diffusion version is at least 0.0.12 (#2022) * Check k-diffusion version is at least 0.0.12 * make style --- setup.py | 2 +- src/diffusers/__init__.py | 1 + src/diffusers/dependency_versions_table.py | 4 ++-- .../pipelines/stable_diffusion/__init__.py | 3 ++- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 16 +++++++++++++++- 6 files changed, 22 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 7bd74062..0a6078fe 100644 --- a/setup.py +++ b/setup.py @@ -91,7 +91,7 @@ _deps = [ "isort>=5.5.4", "jax>=0.2.8,!=0.3.2", "jaxlib>=0.1.65", - "k-diffusion", + "k-diffusion>=0.0.12", "librosa", "modelcards>=0.1.4", "numpy", diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8988cdf1..90af386f 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -6,6 +6,7 @@ from .utils import ( is_flax_available, is_inflect_available, is_k_diffusion_available, + is_k_diffusion_version, is_librosa_available, is_onnx_available, is_scipy_available, diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 1ef1edc1..a995c207 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -4,7 +4,7 @@ deps = { "Pillow": "Pillow", "accelerate": "accelerate>=0.11.0", - "black": "black==22.8", + "black": "black==22.12", "datasets": "datasets", "filelock": "filelock", "flake8": "flake8>=3.8.3", @@ -15,7 +15,7 @@ deps = { "isort": "isort>=5.5.4", "jax": "jax>=0.2.8,!=0.3.2", "jaxlib": "jaxlib>=0.1.65", - "k-diffusion": "k-diffusion", + "k-diffusion": "k-diffusion>=0.0.12", "librosa": "librosa", "modelcards": "modelcards>=0.1.4", "numpy": "numpy", diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 5ed4acff..a152e585 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -11,6 +11,7 @@ from ...utils import ( OptionalDependencyNotAvailable, is_flax_available, is_k_diffusion_available, + is_k_diffusion_version, is_onnx_available, is_torch_available, is_transformers_available, @@ -64,7 +65,7 @@ else: try: - if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): + if not (is_torch_available() and is_transformers_available() and is_k_diffusion_version(">=", "0.0.12")): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 3d059f3f..4f2bc27b 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -47,6 +47,7 @@ from .import_utils import ( is_flax_available, is_inflect_available, is_k_diffusion_available, + is_k_diffusion_version, is_librosa_available, is_modelcards_available, is_onnx_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 91d1d091..724cfd9f 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -427,12 +427,26 @@ def is_transformers_version(operation: str, version: str): operation (`str`): A string representation of an operator, such as `">"` or `"<="` version (`str`): - A string version of PyTorch + A version string """ if not _transformers_available: return False return compare_versions(parse(_transformers_version), operation, version) +def is_k_diffusion_version(operation: str, version: str): + """ + Args: + Compares the current k-diffusion version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _k_diffusion_available: + return False + return compare_versions(parse(_k_diffusion_version), operation, version) + + class OptionalDependencyNotAvailable(BaseException): """An error indicating that an optional dependency of Diffusers was not found in the environment."""