parent
f4781a0b27
commit
721e017401
18
setup.py
18
setup.py
|
@ -68,6 +68,7 @@ To create the package for pypi.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
import os
|
||||||
from distutils.core import Command
|
from distutils.core import Command
|
||||||
|
|
||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
|
@ -82,10 +83,13 @@ _deps = [
|
||||||
"datasets",
|
"datasets",
|
||||||
"filelock",
|
"filelock",
|
||||||
"flake8>=3.8.3",
|
"flake8>=3.8.3",
|
||||||
|
"flax>=0.4.1",
|
||||||
"hf-doc-builder>=0.3.0",
|
"hf-doc-builder>=0.3.0",
|
||||||
"huggingface-hub>=0.8.1",
|
"huggingface-hub>=0.8.1",
|
||||||
"importlib_metadata",
|
"importlib_metadata",
|
||||||
"isort>=5.5.4",
|
"isort>=5.5.4",
|
||||||
|
"jax>=0.2.8,!=0.3.2,<=0.3.6",
|
||||||
|
"jaxlib>=0.1.65,<=0.3.6",
|
||||||
"modelcards==0.1.4",
|
"modelcards==0.1.4",
|
||||||
"numpy",
|
"numpy",
|
||||||
"pytest",
|
"pytest",
|
||||||
|
@ -171,7 +175,14 @@ extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-bui
|
||||||
extras["docs"] = ["hf-doc-builder"]
|
extras["docs"] = ["hf-doc-builder"]
|
||||||
extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"]
|
extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"]
|
||||||
extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"]
|
extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"]
|
||||||
extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"]
|
extras["torch"] = deps_list("torch")
|
||||||
|
|
||||||
|
if os.name == "nt": # windows
|
||||||
|
extras["flax"] = [] # jax is not supported on windows
|
||||||
|
else:
|
||||||
|
extras["flax"] = deps_list("jax", "jaxlib", "flax")
|
||||||
|
|
||||||
|
extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
|
||||||
|
|
||||||
install_requires = [
|
install_requires = [
|
||||||
deps["importlib_metadata"],
|
deps["importlib_metadata"],
|
||||||
|
@ -180,13 +191,12 @@ install_requires = [
|
||||||
deps["numpy"],
|
deps["numpy"],
|
||||||
deps["regex"],
|
deps["regex"],
|
||||||
deps["requests"],
|
deps["requests"],
|
||||||
deps["torch"],
|
|
||||||
deps["Pillow"],
|
deps["Pillow"],
|
||||||
]
|
]
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="diffusers",
|
name="diffusers",
|
||||||
version="0.4.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
version="0.4.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||||
description="Diffusers",
|
description="Diffusers",
|
||||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
@ -198,7 +208,7 @@ setup(
|
||||||
package_dir={"": "src"},
|
package_dir={"": "src"},
|
||||||
packages=find_packages("src"),
|
packages=find_packages("src"),
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
python_requires=">=3.6.0",
|
python_requires=">=3.7.0",
|
||||||
install_requires=install_requires,
|
install_requires=install_requires,
|
||||||
extras_require=extras,
|
extras_require=extras,
|
||||||
entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]},
|
entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]},
|
||||||
|
|
|
@ -2,6 +2,7 @@ from .utils import (
|
||||||
is_inflect_available,
|
is_inflect_available,
|
||||||
is_onnx_available,
|
is_onnx_available,
|
||||||
is_scipy_available,
|
is_scipy_available,
|
||||||
|
is_torch_available,
|
||||||
is_transformers_available,
|
is_transformers_available,
|
||||||
is_unidecode_available,
|
is_unidecode_available,
|
||||||
)
|
)
|
||||||
|
@ -10,40 +11,42 @@ from .utils import (
|
||||||
__version__ = "0.4.0.dev0"
|
__version__ = "0.4.0.dev0"
|
||||||
|
|
||||||
from .configuration_utils import ConfigMixin
|
from .configuration_utils import ConfigMixin
|
||||||
from .modeling_utils import ModelMixin
|
|
||||||
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
|
|
||||||
from .onnx_utils import OnnxRuntimeModel
|
from .onnx_utils import OnnxRuntimeModel
|
||||||
from .optimization import (
|
|
||||||
get_constant_schedule,
|
|
||||||
get_constant_schedule_with_warmup,
|
|
||||||
get_cosine_schedule_with_warmup,
|
|
||||||
get_cosine_with_hard_restarts_schedule_with_warmup,
|
|
||||||
get_linear_schedule_with_warmup,
|
|
||||||
get_polynomial_decay_schedule_with_warmup,
|
|
||||||
get_scheduler,
|
|
||||||
)
|
|
||||||
from .pipeline_utils import DiffusionPipeline
|
|
||||||
from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
|
|
||||||
from .schedulers import (
|
|
||||||
DDIMScheduler,
|
|
||||||
DDPMScheduler,
|
|
||||||
KarrasVeScheduler,
|
|
||||||
PNDMScheduler,
|
|
||||||
SchedulerMixin,
|
|
||||||
ScoreSdeVeScheduler,
|
|
||||||
)
|
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
if is_scipy_available():
|
if is_torch_available():
|
||||||
|
from .modeling_utils import ModelMixin
|
||||||
|
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
|
||||||
|
from .optimization import (
|
||||||
|
get_constant_schedule,
|
||||||
|
get_constant_schedule_with_warmup,
|
||||||
|
get_cosine_schedule_with_warmup,
|
||||||
|
get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||||
|
get_linear_schedule_with_warmup,
|
||||||
|
get_polynomial_decay_schedule_with_warmup,
|
||||||
|
get_scheduler,
|
||||||
|
)
|
||||||
|
from .pipeline_utils import DiffusionPipeline
|
||||||
|
from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
|
||||||
|
from .schedulers import (
|
||||||
|
DDIMScheduler,
|
||||||
|
DDPMScheduler,
|
||||||
|
KarrasVeScheduler,
|
||||||
|
PNDMScheduler,
|
||||||
|
SchedulerMixin,
|
||||||
|
ScoreSdeVeScheduler,
|
||||||
|
)
|
||||||
|
from .training_utils import EMAModel
|
||||||
|
else:
|
||||||
|
from .utils.dummy_pt_objects import * # noqa F403
|
||||||
|
|
||||||
|
if is_torch_available() and is_scipy_available():
|
||||||
from .schedulers import LMSDiscreteScheduler
|
from .schedulers import LMSDiscreteScheduler
|
||||||
else:
|
else:
|
||||||
from .utils.dummy_scipy_objects import * # noqa F403
|
from .utils.dummy_torch_and_scipy_objects import * # noqa F403
|
||||||
|
|
||||||
from .training_utils import EMAModel
|
if is_torch_available() and is_transformers_available():
|
||||||
|
|
||||||
|
|
||||||
if is_transformers_available():
|
|
||||||
from .pipelines import (
|
from .pipelines import (
|
||||||
LDMTextToImagePipeline,
|
LDMTextToImagePipeline,
|
||||||
StableDiffusionImg2ImgPipeline,
|
StableDiffusionImg2ImgPipeline,
|
||||||
|
@ -51,10 +54,9 @@ if is_transformers_available():
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from .utils.dummy_transformers_objects import * # noqa F403
|
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||||
|
|
||||||
|
if is_torch_available() and is_transformers_available() and is_onnx_available():
|
||||||
if is_transformers_available() and is_onnx_available():
|
|
||||||
from .pipelines import StableDiffusionOnnxPipeline
|
from .pipelines import StableDiffusionOnnxPipeline
|
||||||
else:
|
else:
|
||||||
from .utils.dummy_transformers_and_onnx_objects import * # noqa F403
|
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
|
||||||
|
|
|
@ -8,10 +8,13 @@ deps = {
|
||||||
"datasets": "datasets",
|
"datasets": "datasets",
|
||||||
"filelock": "filelock",
|
"filelock": "filelock",
|
||||||
"flake8": "flake8>=3.8.3",
|
"flake8": "flake8>=3.8.3",
|
||||||
|
"flax": "flax>=0.4.1",
|
||||||
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
||||||
"huggingface-hub": "huggingface-hub>=0.8.1",
|
"huggingface-hub": "huggingface-hub>=0.8.1",
|
||||||
"importlib_metadata": "importlib_metadata",
|
"importlib_metadata": "importlib_metadata",
|
||||||
"isort": "isort>=5.5.4",
|
"isort": "isort>=5.5.4",
|
||||||
|
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
|
||||||
|
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
|
||||||
"modelcards": "modelcards==0.1.4",
|
"modelcards": "modelcards==0.1.4",
|
||||||
"numpy": "numpy",
|
"numpy": "numpy",
|
||||||
"pytest": "pytest",
|
"pytest": "pytest",
|
||||||
|
|
|
@ -0,0 +1,165 @@
|
||||||
|
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
from ..utils import DummyObject, requires_backends
|
||||||
|
|
||||||
|
|
||||||
|
class ModelMixin(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class AutoencoderKL(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class UNet2DConditionModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class UNet2DModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class VQModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_constant_schedule(*args, **kwargs):
|
||||||
|
requires_backends(get_constant_schedule, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_constant_schedule_with_warmup(*args, **kwargs):
|
||||||
|
requires_backends(get_constant_schedule_with_warmup, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_cosine_schedule_with_warmup(*args, **kwargs):
|
||||||
|
requires_backends(get_cosine_schedule_with_warmup, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs):
|
||||||
|
requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_linear_schedule_with_warmup(*args, **kwargs):
|
||||||
|
requires_backends(get_linear_schedule_with_warmup, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_polynomial_decay_schedule_with_warmup(*args, **kwargs):
|
||||||
|
requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_scheduler(*args, **kwargs):
|
||||||
|
requires_backends(get_scheduler, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionPipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DDIMPipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DDPMPipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class KarrasVePipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class LDMPipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class PNDMPipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class ScoreSdeVePipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DDIMScheduler(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DDPMScheduler(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class KarrasVeScheduler(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class PNDMScheduler(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerMixin(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class ScoreSdeVeScheduler(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class EMAModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
|
@ -5,7 +5,7 @@ from ..utils import DummyObject, requires_backends
|
||||||
|
|
||||||
|
|
||||||
class LMSDiscreteScheduler(metaclass=DummyObject):
|
class LMSDiscreteScheduler(metaclass=DummyObject):
|
||||||
_backends = ["scipy"]
|
_backends = ["torch", "scipy"]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["scipy"])
|
requires_backends(self, ["torch", "scipy"])
|
|
@ -5,7 +5,7 @@ from ..utils import DummyObject, requires_backends
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionOnnxPipeline(metaclass=DummyObject):
|
class StableDiffusionOnnxPipeline(metaclass=DummyObject):
|
||||||
_backends = ["transformers", "onnx"]
|
_backends = ["torch", "transformers", "onnx"]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["transformers", "onnx"])
|
requires_backends(self, ["torch", "transformers", "onnx"])
|
|
@ -5,28 +5,28 @@ from ..utils import DummyObject, requires_backends
|
||||||
|
|
||||||
|
|
||||||
class LDMTextToImagePipeline(metaclass=DummyObject):
|
class LDMTextToImagePipeline(metaclass=DummyObject):
|
||||||
_backends = ["transformers"]
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["transformers"])
|
requires_backends(self, ["torch", "transformers"])
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionImg2ImgPipeline(metaclass=DummyObject):
|
class StableDiffusionImg2ImgPipeline(metaclass=DummyObject):
|
||||||
_backends = ["transformers"]
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["transformers"])
|
requires_backends(self, ["torch", "transformers"])
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionInpaintPipeline(metaclass=DummyObject):
|
class StableDiffusionInpaintPipeline(metaclass=DummyObject):
|
||||||
_backends = ["transformers"]
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["transformers"])
|
requires_backends(self, ["torch", "transformers"])
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionPipeline(metaclass=DummyObject):
|
class StableDiffusionPipeline(metaclass=DummyObject):
|
||||||
_backends = ["transformers"]
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["transformers"])
|
requires_backends(self, ["torch", "transformers"])
|
|
@ -1,10 +0,0 @@
|
||||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
|
||||||
# flake8: noqa
|
|
||||||
from ..utils import DummyObject, requires_backends
|
|
||||||
|
|
||||||
|
|
||||||
class GradTTSPipeline(metaclass=DummyObject):
|
|
||||||
_backends = ["transformers", "inflect", "unidecode"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["transformers", "inflect", "unidecode"])
|
|
Loading…
Reference in New Issue