Relax DiT test (#2808)

* Relax DiT test

* relax 2 more tests

* fix style

* skip test on Mac due to older protobuf
Author: Kashif Rasul, 2023-03-24 11:28:55 +01:00 (committed by GitHub)
parent 37a44bb283
commit f6feb69991
2 changed files with 14 additions and 3 deletions
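
The relaxations below share one pattern: a generated image is compared against a reference array, and the largest absolute per-pixel difference must stay under a tolerance, which this commit loosens (e.g. from 1e-3 to 1e-2 in the DiT integration test). A minimal sketch of that kind of check, using hypothetical stand-in arrays:

import numpy as np

# Hypothetical stand-ins for the pipeline output and the reference
# loaded via load_numpy(); in the real test both are float arrays of
# identical shape.
image = np.random.rand(64, 64, 3).astype(np.float32)
expected_image = image + np.float32(5e-3)

# Worst-case pixel error must stay below the (now relaxed) tolerance.
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-2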

File 1: DiT pipeline tests

@@ -20,7 +20,7 @@ import numpy as np
 import torch
 from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DPMSolverMultistepScheduler, Transformer2DModel
-from diffusers.utils import load_numpy, slow
+from diffusers.utils import is_xformers_available, load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu
 from ...pipeline_params import (
@@ -97,7 +97,14 @@ class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         self.assertLessEqual(max_diff, 1e-3)
 
     def test_inference_batch_single_identical(self):
-        self._test_inference_batch_single_identical(relax_max_difference=True)
+        self._test_inference_batch_single_identical(relax_max_difference=True, expected_max_diff=1e-3)
+
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
 
     @require_torch_gpu
@@ -123,7 +130,7 @@ class DiTPipelineIntegrationTests(unittest.TestCase):
         expected_image = load_numpy(
             f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}.npy"
         )
-        assert np.abs((expected_image - image).max()) < 1e-3
+        assert np.abs((expected_image - image).max()) < 1e-2
 
     def test_dit_512(self):
         pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512")
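
For context, a usage sketch of the pipeline exercised by test_dit_512. The checkpoint name comes from the diff; the scheduler swap, fp16 cast, label lookup, and step count are illustrative assumptions, not part of this commit:

import torch
from diffusers import DiTPipeline, DPMSolverMultistepScheduler

pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# Map a human-readable ImageNet label to its class id, then sample.
class_ids = pipe.get_label_ids(["white shark"])
image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]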

File 2: Spectrogram Diffusion pipeline tests

@@ -153,6 +153,10 @@ class SpectrogramDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     def test_inference_batch_consistent(self):
         pass
 
+    @skip_mps
+    def test_progress_bar(self):
+        return super().test_progress_bar()
+
 
 @slow
 @require_torch_gpu
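
Per the commit message, the progress-bar test is skipped on Mac because of an older protobuf; skip_mps comes from diffusers' testing utilities. A sketch of what such a guard amounts to (not the library's actual definition), assuming the standard unittest.skipIf mechanism and the torch_device string that this diff also imports:

import unittest
from diffusers.utils import torch_device

# Skip the decorated test whenever the detected torch device is Apple's MPS.
skip_mps = unittest.skipIf(torch_device == "mps", "test requires a non-MPS device")

class ExampleTests(unittest.TestCase):
    @skip_mps
    def test_progress_bar(self):
        ...  # body omitted; the real test lives in the pipeline test mixin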