[Tests] Add accelerate to testing (#729)

* fix accelerate for testing

* fix copies

* uP
This commit is contained in:
Patrick von Platen 2022-10-05 11:35:02 +02:00 committed by GitHub
parent 7265dd8cc8
commit a8a3a20d36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 2 deletions

View File

@ -104,7 +104,6 @@ _deps = [
"torch>=1.4", "torch>=1.4",
"torchvision", "torchvision",
"transformers>=4.21.0", "transformers>=4.21.0",
"accelerate>=0.12.0"
] ]
# this is a lookup table with items like: # this is a lookup table with items like:
@ -179,7 +178,15 @@ extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
extras["docs"] = deps_list("hf-doc-builder") extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards") extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
extras["test"] = deps_list( extras["test"] = deps_list(
"datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "torchvision", "transformers" "accelerate",
"datasets",
"onnxruntime",
"pytest",
"pytest-timeout",
"pytest-xdist",
"scipy",
"torchvision",
"transformers"
) )
extras["torch"] = deps_list("torch") extras["torch"] = deps_list("torch")

View File

@ -67,6 +67,13 @@ class FlaxPNDMScheduler(metaclass=DummyObject):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
class FlaxSchedulerMixin(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxScoreSdeVeScheduler(metaclass=DummyObject): class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
_backends = ["flax"] _backends = ["flax"]