From 0f77139253f5481d62f7c1eddc220355440b2d1f Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Tue, 15 Aug 2023 14:24:55 -0400 Subject: [PATCH] Fix inpaint upload for alpha masks, create reusable function --- modules/img2img.py | 2 +- modules/processing.py | 10 +++++++++- modules/ui.py | 2 +- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index ac9fd3f84..328cb0e9e 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -129,7 +129,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s mask = None elif mode == 2: # inpaint image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] - mask = mask.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0) + mask = processing.create_binary_mask(mask) image = image.convert("RGB") elif mode == 3: # inpaint sketch image = inpaint_color_sketch diff --git a/modules/processing.py b/modules/processing.py index 1d098302e..e62db62fd 100755 --- a/modules/processing.py +++ b/modules/processing.py @@ -81,6 +81,12 @@ def apply_overlay(image, paste_loc, index, overlays): return image +def create_binary_mask(image): + if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255): + image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0) + else: + image = image.convert('L') + return image def txt2img_image_conditioning(sd_model, x, width, height): if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models @@ -1385,7 +1391,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image_mask = self.image_mask if image_mask is not None: - image_mask = image_mask.convert('L') + # image_mask is passed in as RGBA by Gradio to support alpha masks, + # but we still want to support binary masks. + image_mask = create_binary_mask(image_mask) if self.inpainting_mask_invert: image_mask = ImageOps.invert(image_mask) diff --git a/modules/ui.py b/modules/ui.py index a6b1f964b..c98d98496 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -598,7 +598,7 @@ def create_ui(): with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask") with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''