From 3bec90ff2c50df133ced407d6c451c22e5155342 Mon Sep 17 00:00:00 2001 From: Victor Schmidt <9283470+vict0rsch@users.noreply.github.com> Date: Sun, 20 Nov 2022 13:33:09 -0500 Subject: [PATCH] Handle batches and Tensors in `pipeline_stable_diffusion_inpaint.py:prepare_mask_and_masked_image` (#1003) * Handle batches and Tensors in `prepare_mask_and_masked_image` * `blackfy` upgrade `black` * handle mask as `np.array` * add docstring * revert `black` changes with smaller line length * missing ValueError in docstring * raise `TypeError` for image as tensor but not mask * typo in mask shape selection * check for batch dim * fix: wrong indentation * add tests Co-authored-by: Patrick von Platen --- .../pipeline_stable_diffusion_inpaint.py | 95 +++++++++- .../test_stable_diffusion_inpaint.py | 171 ++++++++++++++++++ 2 files changed, 257 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 992b4ca2..f1a613aa 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -35,16 +35,93 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name def prepare_mask_and_masked_image(image, mask): - image = np.array(image.convert("RGB")) - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. + This means that those inputs will be converted to ``torch.Tensor`` with + shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for + the ``image`` and ``1`` for the ``mask``. - mask = np.array(mask.convert("L")) - mask = mask.astype(np.float32) / 255.0 - mask = mask[None, None] - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - mask = torch.from_numpy(mask) + The ``image`` will be converted to ``torch.float32`` and normalized to be in + ``[-1, 1]``. The ``mask`` will be binarized (``mask > 0.5``) and cast to + ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` + or a ``channels x height x width`` ``torch.Tensor`` or a + ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or + a ``1 x height x width`` ``torch.Tensor`` or a + ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. + ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range. + ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + if isinstance(image, PIL.Image.Image): + image = np.array(image.convert("RGB")) + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + if isinstance(mask, PIL.Image.Image): + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) masked_image = image * (mask < 0.5) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index c6a976f4..5c46d909 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -29,7 +29,10 @@ from diffusers import ( UNet2DModel, VQModel, ) + from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image + from diffusers.utils.testing_utils import require_torch_gpu from PIL import Image from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer @@ -506,3 +509,171 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): mem_bytes = torch.cuda.max_memory_allocated() # make sure that less than 2.2 GB is allocated assert mem_bytes < 2.2 * 10**9 + +class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase): + def test_pil_inputs(self): + im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) + im = Image.fromarray(im) + mask = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5 + mask = Image.fromarray((mask * 255).astype(np.uint8)) + + t_mask, t_masked = prepare_mask_and_masked_image(im, mask) + + self.assertTrue(isinstance(t_mask, torch.Tensor)) + self.assertTrue(isinstance(t_masked, torch.Tensor)) + + self.assertEqual(t_mask.ndim, 4) + self.assertEqual(t_masked.ndim, 4) + + self.assertEqual(t_mask.shape, (1, 1, 32, 32)) + self.assertEqual(t_masked.shape, (1, 3, 32, 32)) + + self.assertTrue(t_mask.dtype == torch.float32) + self.assertTrue(t_masked.dtype == torch.float32) + + self.assertTrue(t_mask.min() >= 0.0) + self.assertTrue(t_mask.max() <= 1.0) + self.assertTrue(t_masked.min() >= -1.0) + self.assertTrue(t_masked.min() <= 1.0) + + self.assertTrue(t_mask.sum() > 0.0) + + def test_np_inputs(self): + im_np = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) + im_pil = Image.fromarray(im_np) + mask_np = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5 + mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8)) + + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil) + + self.assertTrue((t_mask_np == t_mask_pil).all()) + self.assertTrue((t_masked_np == t_masked_pil).all()) + + def test_torch_3D_2D_inputs(self): + im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5 + im_np = im_tensor.numpy().transpose(1, 2, 0) + mask_np = mask_tensor.numpy() + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_torch_3D_3D_inputs(self): + im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5 + im_np = im_tensor.numpy().transpose(1, 2, 0) + mask_np = mask_tensor.numpy()[0] + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_torch_4D_2D_inputs(self): + im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5 + im_np = im_tensor.numpy()[0].transpose(1, 2, 0) + mask_np = mask_tensor.numpy() + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_torch_4D_3D_inputs(self): + im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5 + im_np = im_tensor.numpy()[0].transpose(1, 2, 0) + mask_np = mask_tensor.numpy()[0] + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_torch_4D_4D_inputs(self): + im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (1, 1, 32, 32), dtype=torch.uint8) > 127.5 + im_np = im_tensor.numpy()[0].transpose(1, 2, 0) + mask_np = mask_tensor.numpy()[0][0] + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_torch_batch_4D_3D(self): + im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (2, 32, 32), dtype=torch.uint8) > 127.5 + + im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] + mask_nps = [mask.numpy() for mask in mask_tensor] + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)] + t_mask_np = torch.cat([n[0] for n in nps]) + t_masked_np = torch.cat([n[1] for n in nps]) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_torch_batch_4D_4D(self): + im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (2, 1, 32, 32), dtype=torch.uint8) > 127.5 + + im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] + mask_nps = [mask.numpy()[0] for mask in mask_tensor] + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)] + t_mask_np = torch.cat([n[0] for n in nps]) + t_masked_np = torch.cat([n[1] for n in nps]) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_shape_mismatch(self): + # test height and width + with self.assertRaises(AssertionError): + prepare_mask_and_masked_image(torch.randn(3, 32, 32), torch.randn(64, 64)) + # test batch dim + with self.assertRaises(AssertionError): + prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 64, 64)) + # test batch dim + with self.assertRaises(AssertionError): + prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 1, 64, 64)) + + def test_type_mismatch(self): + # test tensors-only + with self.assertRaises(TypeError): + prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.rand(3, 32, 32).numpy()) + # test tensors-only + with self.assertRaises(TypeError): + prepare_mask_and_masked_image(torch.rand(3, 32, 32).numpy(), torch.rand(3, 32, 32)) + + def test_channels_first(self): + # test channels first for 3D tensors + with self.assertRaises(AssertionError): + prepare_mask_and_masked_image(torch.rand(32, 32, 3), torch.rand(3, 32, 32)) + + def test_tensor_range(self): + # test im <= 1 + with self.assertRaises(ValueError): + prepare_mask_and_masked_image(torch.ones(3, 32, 32) * 2, torch.rand(32, 32)) + # test im >= -1 + with self.assertRaises(ValueError): + prepare_mask_and_masked_image(torch.ones(3, 32, 32) * (-2), torch.rand(32, 32)) + # test mask <= 1 + with self.assertRaises(ValueError): + prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * 2) + # test mask >= 0 + with self.assertRaises(ValueError): + prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * -1) \ No newline at end of file