[Flax] Make room for more frameworks (#494)

* start

* finish
This commit is contained in:
Patrick von Platen 2022-09-13 13:24:27 +02:00 committed by GitHub
parent f4781a0b27
commit 721e017401
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 227 additions and 57 deletions

View File

@ -68,6 +68,7 @@ To create the package for pypi.
"""
import re
import os
from distutils.core import Command
from setuptools import find_packages, setup
@ -82,10 +83,13 @@ _deps = [
"datasets",
"filelock",
"flake8>=3.8.3",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.8.1",
"importlib_metadata",
"isort>=5.5.4",
"jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib>=0.1.65,<=0.3.6",
"modelcards==0.1.4",
"numpy",
"pytest",
@ -171,7 +175,14 @@ extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-bui
extras["docs"] = ["hf-doc-builder"]
extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"]
extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"]
extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"]
extras["torch"] = deps_list("torch")
if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
else:
extras["flax"] = deps_list("jax", "jaxlib", "flax")
extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
install_requires = [
deps["importlib_metadata"],
@ -180,7 +191,6 @@ install_requires = [
deps["numpy"],
deps["regex"],
deps["requests"],
deps["torch"],
deps["Pillow"],
]
@ -198,7 +208,7 @@ setup(
package_dir={"": "src"},
packages=find_packages("src"),
include_package_data=True,
python_requires=">=3.6.0",
python_requires=">=3.7.0",
install_requires=install_requires,
extras_require=extras,
entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]},

View File

@ -2,6 +2,7 @@ from .utils import (
is_inflect_available,
is_onnx_available,
is_scipy_available,
is_torch_available,
is_transformers_available,
is_unidecode_available,
)
@ -10,10 +11,14 @@ from .utils import (
__version__ = "0.4.0.dev0"
from .configuration_utils import ConfigMixin
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from .onnx_utils import OnnxRuntimeModel
from .optimization import (
from .utils import logging
if is_torch_available():
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
@ -21,29 +26,27 @@ from .optimization import (
get_linear_schedule_with_warmup,
get_polynomial_decay_schedule_with_warmup,
get_scheduler,
)
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
from .schedulers import (
)
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
from .schedulers import (
DDIMScheduler,
DDPMScheduler,
KarrasVeScheduler,
PNDMScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
)
from .utils import logging
)
from .training_utils import EMAModel
else:
from .utils.dummy_pt_objects import * # noqa F403
if is_scipy_available():
if is_torch_available() and is_scipy_available():
from .schedulers import LMSDiscreteScheduler
else:
from .utils.dummy_scipy_objects import * # noqa F403
from .utils.dummy_torch_and_scipy_objects import * # noqa F403
from .training_utils import EMAModel
if is_transformers_available():
if is_torch_available() and is_transformers_available():
from .pipelines import (
LDMTextToImagePipeline,
StableDiffusionImg2ImgPipeline,
@ -51,10 +54,9 @@ if is_transformers_available():
StableDiffusionPipeline,
)
else:
from .utils.dummy_transformers_objects import * # noqa F403
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
if is_transformers_available() and is_onnx_available():
if is_torch_available() and is_transformers_available() and is_onnx_available():
from .pipelines import StableDiffusionOnnxPipeline
else:
from .utils.dummy_transformers_and_onnx_objects import * # noqa F403
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403

View File

@ -8,10 +8,13 @@ deps = {
"datasets": "datasets",
"filelock": "filelock",
"flake8": "flake8>=3.8.3",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.8.1",
"importlib_metadata": "importlib_metadata",
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
"modelcards": "modelcards==0.1.4",
"numpy": "numpy",
"pytest": "pytest",

View File

@ -0,0 +1,165 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends
class ModelMixin(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class AutoencoderKL(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class UNet2DConditionModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class UNet2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class VQModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
def get_constant_schedule(*args, **kwargs):
requires_backends(get_constant_schedule, ["torch"])
def get_constant_schedule_with_warmup(*args, **kwargs):
requires_backends(get_constant_schedule_with_warmup, ["torch"])
def get_cosine_schedule_with_warmup(*args, **kwargs):
requires_backends(get_cosine_schedule_with_warmup, ["torch"])
def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs):
requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"])
def get_linear_schedule_with_warmup(*args, **kwargs):
requires_backends(get_linear_schedule_with_warmup, ["torch"])
def get_polynomial_decay_schedule_with_warmup(*args, **kwargs):
requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch"])
def get_scheduler(*args, **kwargs):
requires_backends(get_scheduler, ["torch"])
class DiffusionPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DDIMPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DDPMPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class KarrasVePipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LDMPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PNDMPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ScoreSdeVePipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DDIMScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DDPMScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class KarrasVeScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PNDMScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class SchedulerMixin(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ScoreSdeVeScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class EMAModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

View File

@ -5,7 +5,7 @@ from ..utils import DummyObject, requires_backends
class LMSDiscreteScheduler(metaclass=DummyObject):
_backends = ["scipy"]
_backends = ["torch", "scipy"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["scipy"])
requires_backends(self, ["torch", "scipy"])

View File

@ -5,7 +5,7 @@ from ..utils import DummyObject, requires_backends
class StableDiffusionOnnxPipeline(metaclass=DummyObject):
_backends = ["transformers", "onnx"]
_backends = ["torch", "transformers", "onnx"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers", "onnx"])
requires_backends(self, ["torch", "transformers", "onnx"])

View File

@ -5,28 +5,28 @@ from ..utils import DummyObject, requires_backends
class LDMTextToImagePipeline(metaclass=DummyObject):
_backends = ["transformers"]
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])
requires_backends(self, ["torch", "transformers"])
class StableDiffusionImg2ImgPipeline(metaclass=DummyObject):
_backends = ["transformers"]
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])
requires_backends(self, ["torch", "transformers"])
class StableDiffusionInpaintPipeline(metaclass=DummyObject):
_backends = ["transformers"]
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])
requires_backends(self, ["torch", "transformers"])
class StableDiffusionPipeline(metaclass=DummyObject):
_backends = ["transformers"]
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])
requires_backends(self, ["torch", "transformers"])

View File

@ -1,10 +0,0 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends
class GradTTSPipeline(metaclass=DummyObject):
_backends = ["transformers", "inflect", "unidecode"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers", "inflect", "unidecode"])