Check k-diffusion version is at least 0.0.12 (#2022)

* Check k-diffusion version is at least 0.0.12

* make style
This commit is contained in:
Pedro Cuenca 2023-01-17 21:16:46 +01:00 committed by GitHub
parent a43bdd01cd
commit 7e29b747f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 22 additions and 5 deletions

View File

@ -91,7 +91,7 @@ _deps = [
"isort>=5.5.4", "isort>=5.5.4",
"jax>=0.2.8,!=0.3.2", "jax>=0.2.8,!=0.3.2",
"jaxlib>=0.1.65", "jaxlib>=0.1.65",
"k-diffusion", "k-diffusion>=0.0.12",
"librosa", "librosa",
"modelcards>=0.1.4", "modelcards>=0.1.4",
"numpy", "numpy",

View File

@ -6,6 +6,7 @@ from .utils import (
is_flax_available, is_flax_available,
is_inflect_available, is_inflect_available,
is_k_diffusion_available, is_k_diffusion_available,
is_k_diffusion_version,
is_librosa_available, is_librosa_available,
is_onnx_available, is_onnx_available,
is_scipy_available, is_scipy_available,

View File

@ -4,7 +4,7 @@
deps = { deps = {
"Pillow": "Pillow", "Pillow": "Pillow",
"accelerate": "accelerate>=0.11.0", "accelerate": "accelerate>=0.11.0",
"black": "black==22.8", "black": "black==22.12",
"datasets": "datasets", "datasets": "datasets",
"filelock": "filelock", "filelock": "filelock",
"flake8": "flake8>=3.8.3", "flake8": "flake8>=3.8.3",
@ -15,7 +15,7 @@ deps = {
"isort": "isort>=5.5.4", "isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2", "jax": "jax>=0.2.8,!=0.3.2",
"jaxlib": "jaxlib>=0.1.65", "jaxlib": "jaxlib>=0.1.65",
"k-diffusion": "k-diffusion", "k-diffusion": "k-diffusion>=0.0.12",
"librosa": "librosa", "librosa": "librosa",
"modelcards": "modelcards>=0.1.4", "modelcards": "modelcards>=0.1.4",
"numpy": "numpy", "numpy": "numpy",

View File

@ -11,6 +11,7 @@ from ...utils import (
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
is_flax_available, is_flax_available,
is_k_diffusion_available, is_k_diffusion_available,
is_k_diffusion_version,
is_onnx_available, is_onnx_available,
is_torch_available, is_torch_available,
is_transformers_available, is_transformers_available,
@ -64,7 +65,7 @@ else:
try: try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): if not (is_torch_available() and is_transformers_available() and is_k_diffusion_version(">=", "0.0.12")):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403

View File

@ -47,6 +47,7 @@ from .import_utils import (
is_flax_available, is_flax_available,
is_inflect_available, is_inflect_available,
is_k_diffusion_available, is_k_diffusion_available,
is_k_diffusion_version,
is_librosa_available, is_librosa_available,
is_modelcards_available, is_modelcards_available,
is_onnx_available, is_onnx_available,

View File

@ -427,12 +427,26 @@ def is_transformers_version(operation: str, version: str):
operation (`str`): operation (`str`):
A string representation of an operator, such as `">"` or `"<="` A string representation of an operator, such as `">"` or `"<="`
version (`str`): version (`str`):
A string version of PyTorch A version string
""" """
if not _transformers_available: if not _transformers_available:
return False return False
return compare_versions(parse(_transformers_version), operation, version) return compare_versions(parse(_transformers_version), operation, version)
def is_k_diffusion_version(operation: str, version: str):
"""
Args:
Compares the current k-diffusion version to a given reference with an operation.
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _k_diffusion_available:
return False
return compare_versions(parse(_k_diffusion_version), operation, version)
class OptionalDependencyNotAvailable(BaseException): class OptionalDependencyNotAvailable(BaseException):
"""An error indicating that an optional dependency of Diffusers was not found in the environment.""" """An error indicating that an optional dependency of Diffusers was not found in the environment."""