[Tests] Add MPS skip decorator (#2362)
* finish * Apply suggestions from code review * fix indent and import error in test_stable_diffusion_depth --------- Co-authored-by: William Berman <WLBberman@gmail.com>
This commit is contained in:
parent
2a49fac864
commit
4c52982a0b
|
@ -79,6 +79,7 @@ if is_torch_available():
|
|||
parse_flag_from_env,
|
||||
print_tensor_test,
|
||||
require_torch_gpu,
|
||||
skip_mps,
|
||||
slow,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
|
|
|
@ -163,6 +163,11 @@ def require_torch_gpu(test_case):
|
|||
)
|
||||
|
||||
|
||||
def skip_mps(test_case):
|
||||
"""Decorator marking a test to skip if torch_device is 'mps'"""
|
||||
return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
|
||||
|
||||
|
||||
def require_flax(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
|
||||
|
|
|
@ -39,9 +39,8 @@ from diffusers import (
|
|||
StableDiffusionDepth2ImgPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
from diffusers.utils import floats_tensor, is_accelerate_available, load_image, load_numpy, nightly, slow, torch_device
|
||||
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
@ -49,7 +48,7 @@ from ...test_pipelines_common import PipelineTesterMixin
|
|||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet")
|
||||
@skip_mps
|
||||
class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = StableDiffusionDepth2ImgPipeline
|
||||
test_save_load_optional_components = False
|
||||
|
@ -154,7 +153,6 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
|
|||
}
|
||||
return inputs
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet")
|
||||
def test_save_load_local(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
|
@ -248,7 +246,6 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
|
|||
max_diff = np.abs(output_with_offload - output_without_offload).max()
|
||||
self.assertLess(max_diff, 1e-4, "CPU offloading should not affect the inference results")
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet")
|
||||
def test_dict_tuple_outputs_equivalent(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
|
@ -265,7 +262,6 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
|
|||
max_diff = np.abs(output - output_tuple).max()
|
||||
self.assertLess(max_diff, 1e-4)
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet")
|
||||
def test_progress_bar(self):
|
||||
super().test_progress_bar()
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokeni
|
|||
from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel
|
||||
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
|
||||
from diffusers.utils import load_numpy, nightly, slow, torch_device
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
|
||||
|
||||
|
@ -349,7 +349,7 @@ class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||
|
||||
# Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
|
||||
# because UnCLIP GPU undeterminism requires a looser check.
|
||||
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
|
||||
@skip_mps
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
test_max_difference = torch_device == "cpu"
|
||||
|
||||
|
@ -357,7 +357,7 @@ class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||
|
||||
# Overriding PipelineTesterMixin::test_inference_batch_single_identical
|
||||
# because UnCLIP undeterminism requires a looser check.
|
||||
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
|
||||
@skip_mps
|
||||
def test_inference_batch_single_identical(self):
|
||||
test_max_difference = torch_device == "cpu"
|
||||
relax_max_difference = True
|
||||
|
@ -374,15 +374,15 @@ class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
|||
else:
|
||||
self._test_inference_batch_consistent()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
|
||||
@skip_mps
|
||||
def test_dict_tuple_outputs_equivalent(self):
|
||||
return super().test_dict_tuple_outputs_equivalent()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
|
||||
@skip_mps
|
||||
def test_save_load_local(self):
|
||||
return super().test_save_load_local()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
|
||||
@skip_mps
|
||||
def test_save_load_optional_components(self):
|
||||
return super().test_save_load_optional_components()
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ from diffusers import (
|
|||
)
|
||||
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
|
||||
from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
|
||||
from diffusers.utils.testing_utils import load_image, require_torch_gpu
|
||||
from diffusers.utils.testing_utils import load_image, require_torch_gpu, skip_mps
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
|
||||
|
||||
|
@ -470,7 +470,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
|
|||
|
||||
# Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
|
||||
# because UnCLIP GPU undeterminism requires a looser check.
|
||||
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
|
||||
@skip_mps
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
test_max_difference = torch_device == "cpu"
|
||||
|
||||
|
@ -478,7 +478,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
|
|||
|
||||
# Overriding PipelineTesterMixin::test_inference_batch_single_identical
|
||||
# because UnCLIP undeterminism requires a looser check.
|
||||
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
|
||||
@skip_mps
|
||||
def test_inference_batch_single_identical(self):
|
||||
test_max_difference = torch_device == "cpu"
|
||||
relax_max_difference = True
|
||||
|
@ -495,15 +495,15 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
|
|||
else:
|
||||
self._test_inference_batch_consistent()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
|
||||
@skip_mps
|
||||
def test_dict_tuple_outputs_equivalent(self):
|
||||
return super().test_dict_tuple_outputs_equivalent()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
|
||||
@skip_mps
|
||||
def test_save_load_local(self):
|
||||
return super().test_save_load_local()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
|
||||
@skip_mps
|
||||
def test_save_load_optional_components(self):
|
||||
return super().test_save_load_optional_components()
|
||||
|
||||
|
|
Loading…
Reference in New Issue