Skip `mps` in text-to-video tests (#2792)

* Skip mps in text-to-video tests.

* style

* Skip UNet3D mps tests.
This commit is contained in:
Pedro Cuenca 2023-03-23 13:39:03 +00:00 committed by GitHub
parent dc5b4e2342
commit aa0531fa8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 1 deletions

View File

@ -23,6 +23,7 @@ from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.utils import ( from diffusers.utils import (
floats_tensor, floats_tensor,
logging, logging,
skip_mps,
torch_device, torch_device,
) )
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
@ -60,6 +61,7 @@ def create_lora_layers(model):
return lora_attn_procs return lora_attn_procs
@skip_mps
class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase): class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNet3DConditionModel model_class = UNet3DConditionModel

View File

@ -35,6 +35,7 @@ from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
@skip_mps
class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = TextToVideoSDPipeline pipeline_class = TextToVideoSDPipeline
params = TEXT_TO_IMAGE_PARAMS params = TEXT_TO_IMAGE_PARAMS
@ -155,12 +156,12 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def test_num_images_per_prompt(self): def test_num_images_per_prompt(self):
pass pass
@skip_mps
def test_progress_bar(self): def test_progress_bar(self):
return super().test_progress_bar() return super().test_progress_bar()
@slow @slow
@skip_mps
class TextToVideoSDPipelineSlowTests(unittest.TestCase): class TextToVideoSDPipelineSlowTests(unittest.TestCase):
def test_full_model(self): def test_full_model(self):
expected_video = load_numpy( expected_video = load_numpy(