Correct fast tests (#2314)

* correct some

* Apply suggestions from code review

* correct

* Update tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py

* Final
This commit is contained in:
Patrick von Platen 2023-02-10 15:12:34 +02:00 committed by GitHub
parent 716286f19d
commit 96c2279bcd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 11 deletions

View File

@ -159,12 +159,10 @@ class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3) assert image.shape == (1, 32, 32, 3)
expected_slice = np.array( expected_slice = np.array([0.4115, 0.3870, 0.4089, 0.4807, 0.4668, 0.4144, 0.4151, 0.4721, 0.4569])
[0.41293705, 0.38656747, 0.40876025, 0.4782187, 0.4656803, 0.41394007, 0.4142093, 0.47150758, 0.4570448]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1.5e-3 assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1.5e-3 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 5e-3
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU") @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
def test_stable_diffusion_img2img_fp16(self): def test_stable_diffusion_img2img_fp16(self):

View File

@ -288,7 +288,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
if torch_device == "mps": if torch_device == "mps":
expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546]) expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546])
else: else:
expected_slice = np.array([0.6854, 0.3740, 0.4857, 0.7130, 0.7403, 0.5536, 0.4829, 0.6182, 0.5053]) expected_slice = np.array([0.6374, 0.5039, 0.4199, 0.4819, 0.5563, 0.4617, 0.4028, 0.5381, 0.4711])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@ -309,7 +309,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
if torch_device == "mps": if torch_device == "mps":
expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335]) expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335])
else: else:
expected_slice = np.array([0.6074, 0.3096, 0.4802, 0.7463, 0.7388, 0.5393, 0.4531, 0.5928, 0.4972]) expected_slice = np.array([0.6332, 0.5167, 0.3911, 0.4446, 0.5971, 0.4619, 0.3821, 0.5323, 0.4621])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@ -331,7 +331,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
if torch_device == "mps": if torch_device == "mps":
expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551]) expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551])
else: else:
expected_slice = np.array([0.6681, 0.5023, 0.6611, 0.7605, 0.5724, 0.7959, 0.7240, 0.5871, 0.5383]) expected_slice = np.array([0.6248, 0.5206, 0.6007, 0.6749, 0.5022, 0.6442, 0.5352, 0.4140, 0.4681])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@ -386,7 +386,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
if torch_device == "mps": if torch_device == "mps":
expected_slice = np.array([0.53232, 0.47015, 0.40868, 0.45651, 0.4891, 0.4668, 0.4287, 0.48822, 0.47439]) expected_slice = np.array([0.53232, 0.47015, 0.40868, 0.45651, 0.4891, 0.4668, 0.4287, 0.48822, 0.47439])
else: else:
expected_slice = np.array([0.6853, 0.3740, 0.4856, 0.7130, 0.7402, 0.5535, 0.4828, 0.6182, 0.5053]) expected_slice = np.array([0.6374, 0.5039, 0.4199, 0.4819, 0.5563, 0.4617, 0.4028, 0.5381, 0.4711])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

View File

@ -17,7 +17,8 @@ from collections import defaultdict
def overwrite_file(file, class_name, test_name, correct_line, done_test): def overwrite_file(file, class_name, test_name, correct_line, done_test):
done_test[file] += 1 _id = f"{file}_{class_name}_{test_name}"
done_test[_id] += 1
with open(file, "r") as f: with open(file, "r") as f:
lines = f.readlines() lines = f.readlines()
@ -43,7 +44,7 @@ def overwrite_file(file, class_name, test_name, correct_line, done_test):
spaces = len(line.split(correct_line.split()[0])[0]) spaces = len(line.split(correct_line.split()[0])[0])
count += 1 count += 1
if count == done_test[file]: if count == done_test[_id]:
in_line = True in_line = True
if in_class and in_func and in_line: if in_class and in_func and in_line: