[Tests] Add accelerate to testing (#729)
* fix accelerate for testing * fix copies * uP
This commit is contained in:
parent
7265dd8cc8
commit
a8a3a20d36
11
setup.py
11
setup.py
|
@ -104,7 +104,6 @@ _deps = [
|
||||||
"torch>=1.4",
|
"torch>=1.4",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
"transformers>=4.21.0",
|
"transformers>=4.21.0",
|
||||||
"accelerate>=0.12.0"
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# this is a lookup table with items like:
|
# this is a lookup table with items like:
|
||||||
|
@ -179,7 +178,15 @@ extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
|
||||||
extras["docs"] = deps_list("hf-doc-builder")
|
extras["docs"] = deps_list("hf-doc-builder")
|
||||||
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
|
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
|
||||||
extras["test"] = deps_list(
|
extras["test"] = deps_list(
|
||||||
"datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "torchvision", "transformers"
|
"accelerate",
|
||||||
|
"datasets",
|
||||||
|
"onnxruntime",
|
||||||
|
"pytest",
|
||||||
|
"pytest-timeout",
|
||||||
|
"pytest-xdist",
|
||||||
|
"scipy",
|
||||||
|
"torchvision",
|
||||||
|
"transformers"
|
||||||
)
|
)
|
||||||
extras["torch"] = deps_list("torch")
|
extras["torch"] = deps_list("torch")
|
||||||
|
|
||||||
|
|
|
@ -67,6 +67,13 @@ class FlaxPNDMScheduler(metaclass=DummyObject):
|
||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxSchedulerMixin(metaclass=DummyObject):
|
||||||
|
_backends = ["flax"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
|
class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
|
||||||
_backends = ["flax"]
|
_backends = ["flax"]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue