[Tests] Add MPS skip decorator (#2362)

* finish

* Apply suggestions from code review

* fix indent and import error in test_stable_diffusion_depth

---------

Co-authored-by: William Berman <WLBberman@gmail.com>
Patrick von Platen authored on 2023-02-15 23:17:25 +02:00 (committed by GitHub)
commit 4c52982a0b, parent 2a49fac864
5 changed files with 21 additions and 19 deletions


@@ -79,6 +79,7 @@ if is_torch_available():
         parse_flag_from_env,
         print_tensor_test,
         require_torch_gpu,
+        skip_mps,
         slow,
         torch_all_close,
         torch_device,
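
The hunk above re-exports skip_mps through the package-level diffusers.utils namespace, behind the is_torch_available() guard. A minimal sketch of the two resulting import paths (assuming a torch-enabled install of diffusers at this commit):

    # Both paths resolve to the same decorator object after this change.
    from diffusers.utils import skip_mps as skip_mps_reexported  # package-level re-export
    from diffusers.utils.testing_utils import skip_mps           # defining module

    assert skip_mps_reexported is skip_mps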


@@ -163,6 +163,11 @@ def require_torch_gpu(test_case):
     )


+def skip_mps(test_case):
+    """Decorator marking a test to skip if torch_device is 'mps'"""
+    return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
+
+
 def require_flax(test_case):
     """
     Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
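
For context, a minimal sketch of how the new decorator is applied; the test classes here are hypothetical, not part of this commit. Note that unittest.skipUnless(torch_device != "mps", ...) is equivalent to the unittest.skipIf(torch_device == "mps", ...) lines it replaces in the test files below:

    import unittest

    from diffusers.utils.testing_utils import skip_mps


    @skip_mps  # class-level: every test in the class is skipped on MPS
    class ExamplePipelineFastTests(unittest.TestCase):  # hypothetical class
        def test_forward(self):
            self.assertTrue(True)


    class ExampleOtherTests(unittest.TestCase):  # hypothetical class
        @skip_mps  # method-level: only this test is skipped on MPS
        def test_mps_flaky(self):
            self.assertTrue(True)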


@@ -39,9 +39,8 @@ from diffusers import (
     StableDiffusionDepth2ImgPipeline,
     UNet2DConditionModel,
 )
-from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
-from diffusers.utils.import_utils import is_accelerate_available
-from diffusers.utils.testing_utils import require_torch_gpu
+from diffusers.utils import floats_tensor, is_accelerate_available, load_image, load_numpy, nightly, slow, torch_device
+from diffusers.utils.testing_utils import require_torch_gpu, skip_mps

 from ...test_pipelines_common import PipelineTesterMixin
@@ -49,7 +48,7 @@ from ...test_pipelines_common import PipelineTesterMixin
 torch.backends.cuda.matmul.allow_tf32 = False


-@unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet")
+@skip_mps
 class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionDepth2ImgPipeline
     test_save_load_optional_components = False
@@ -154,7 +153,6 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         }
         return inputs

-    @unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet")
     def test_save_load_local(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
@@ -248,7 +246,6 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         max_diff = np.abs(output_with_offload - output_without_offload).max()
         self.assertLess(max_diff, 1e-4, "CPU offloading should not affect the inference results")

-    @unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet")
     def test_dict_tuple_outputs_equivalent(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
@@ -265,7 +262,6 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         max_diff = np.abs(output - output_tuple).max()
         self.assertLess(max_diff, 1e-4)

-    @unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet")
     def test_progress_bar(self):
         super().test_progress_bar()
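
The per-method skipIf decorators dropped above became redundant once the class itself is decorated: in unittest, a skip decorator applied to a TestCase class skips every test method in it. A small self-contained sketch of that behavior (the class name and hard-coded device are illustrative, not from the commit):

    import unittest

    torch_device = "mps"  # illustrative; in diffusers this comes from testing_utils


    @unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")
    class DeviceGatedTests(unittest.TestCase):  # hypothetical class
        def test_a(self):  # skipped on mps, no per-method decorator needed
            pass

        def test_b(self):  # likewise skipped
            pass


    if __name__ == "__main__":
        unittest.main()  # on "mps" this reports both tests as skipped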


@@ -23,7 +23,7 @@ from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
 from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel
 from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
 from diffusers.utils import load_numpy, nightly, slow, torch_device
-from diffusers.utils.testing_utils import require_torch_gpu
+from diffusers.utils.testing_utils import require_torch_gpu, skip_mps

 from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
@@ -349,7 +349,7 @@ class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

     # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
     # because UnCLIP GPU undeterminism requires a looser check.
-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_attention_slicing_forward_pass(self):
         test_max_difference = torch_device == "cpu"
@@ -357,7 +357,7 @@ class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

     # Overriding PipelineTesterMixin::test_inference_batch_single_identical
     # because UnCLIP undeterminism requires a looser check.
-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_inference_batch_single_identical(self):
         test_max_difference = torch_device == "cpu"
         relax_max_difference = True
@@ -374,15 +374,15 @@ class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         else:
             self._test_inference_batch_consistent()

-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_dict_tuple_outputs_equivalent(self):
         return super().test_dict_tuple_outputs_equivalent()

-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_save_load_local(self):
         return super().test_save_load_local()

-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_save_load_optional_components(self):
         return super().test_save_load_optional_components()
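
Unlike the depth tests, these UnCLIP overrides keep a method-level decorator: the original tests live on PipelineTesterMixin, and the subclass re-declares them only to gate them on MPS. A sketch of that pattern with hypothetical names:

    import unittest

    from diffusers.utils.testing_utils import skip_mps


    class ExampleMixin:  # stand-in for PipelineTesterMixin
        def test_save_load_local(self):
            pass  # shared test body reused by many pipeline test classes


    class ExamplePipelineFastTests(ExampleMixin, unittest.TestCase):  # hypothetical
        @skip_mps  # override the inherited test just to attach the skip
        def test_save_load_local(self):
            return super().test_save_load_local()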


@@ -37,7 +37,7 @@ from diffusers import (
 )
 from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
 from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
-from diffusers.utils.testing_utils import load_image, require_torch_gpu
+from diffusers.utils.testing_utils import load_image, require_torch_gpu, skip_mps

 from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
@@ -470,7 +470,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

     # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
     # because UnCLIP GPU undeterminism requires a looser check.
-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_attention_slicing_forward_pass(self):
         test_max_difference = torch_device == "cpu"
@@ -478,7 +478,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

     # Overriding PipelineTesterMixin::test_inference_batch_single_identical
     # because UnCLIP undeterminism requires a looser check.
-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_inference_batch_single_identical(self):
         test_max_difference = torch_device == "cpu"
         relax_max_difference = True
@@ -495,15 +495,15 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         else:
             self._test_inference_batch_consistent()

-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_dict_tuple_outputs_equivalent(self):
         return super().test_dict_tuple_outputs_equivalent()

-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_save_load_local(self):
         return super().test_save_load_local()

-    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    @skip_mps
     def test_save_load_optional_components(self):
         return super().test_save_load_optional_components()