make style
This commit is contained in:
parent
2b31740d54
commit
ab1f01e634
|
@ -36,30 +36,25 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
def prepare_mask_and_masked_image(image, mask):
|
def prepare_mask_and_masked_image(image, mask):
|
||||||
"""
|
"""
|
||||||
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline.
|
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
|
||||||
This means that those inputs will be converted to ``torch.Tensor`` with
|
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
|
||||||
shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for
|
``image`` and ``1`` for the ``mask``.
|
||||||
the ``image`` and ``1`` for the ``mask``.
|
|
||||||
|
|
||||||
The ``image`` will be converted to ``torch.float32`` and normalized to be in
|
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
|
||||||
``[-1, 1]``. The ``mask`` will be binarized (``mask > 0.5``) and cast to
|
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
|
||||||
``torch.float32`` too.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
|
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
|
||||||
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array``
|
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
|
||||||
or a ``channels x height x width`` ``torch.Tensor`` or a
|
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
|
||||||
``batch x channels x height x width`` ``torch.Tensor``.
|
|
||||||
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
|
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
|
||||||
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or
|
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
|
||||||
a ``1 x height x width`` ``torch.Tensor`` or a
|
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
|
||||||
``batch x 1 x height x width`` ``torch.Tensor``.
|
|
||||||
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
|
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
|
||||||
ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
|
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
|
||||||
ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
|
|
||||||
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
|
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
|
||||||
(ot the other way around).
|
(ot the other way around).
|
||||||
|
|
||||||
|
|
|
@ -29,10 +29,8 @@ from diffusers import (
|
||||||
UNet2DModel,
|
UNet2DModel,
|
||||||
VQModel,
|
VQModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
|
|
||||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
|
||||||
|
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
|
||||||
from diffusers.utils.testing_utils import require_torch_gpu
|
from diffusers.utils.testing_utils import require_torch_gpu
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||||
|
@ -510,6 +508,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||||
# make sure that less than 2.2 GB is allocated
|
# make sure that less than 2.2 GB is allocated
|
||||||
assert mem_bytes < 2.2 * 10**9
|
assert mem_bytes < 2.2 * 10**9
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
|
class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
|
||||||
def test_pil_inputs(self):
|
def test_pil_inputs(self):
|
||||||
im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
|
im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
|
||||||
|
|
Loading…
Reference in New Issue