[Tests] Add accelerate to testing (#729)
* fix accelerate for testing * fix copies * uP
This commit is contained in:
parent
7265dd8cc8
commit
a8a3a20d36
11
setup.py
11
setup.py
|
@ -104,7 +104,6 @@ _deps = [
|
|||
"torch>=1.4",
|
||||
"torchvision",
|
||||
"transformers>=4.21.0",
|
||||
"accelerate>=0.12.0"
|
||||
]
|
||||
|
||||
# this is a lookup table with items like:
|
||||
|
@ -179,7 +178,15 @@ extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
|
|||
extras["docs"] = deps_list("hf-doc-builder")
|
||||
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
|
||||
extras["test"] = deps_list(
|
||||
"datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "torchvision", "transformers"
|
||||
"accelerate",
|
||||
"datasets",
|
||||
"onnxruntime",
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
"pytest-xdist",
|
||||
"scipy",
|
||||
"torchvision",
|
||||
"transformers"
|
||||
)
|
||||
extras["torch"] = deps_list("torch")
|
||||
|
||||
|
|
|
@ -67,6 +67,13 @@ class FlaxPNDMScheduler(metaclass=DummyObject):
|
|||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxSchedulerMixin(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
|
|
Loading…
Reference in New Issue