added first version of inpainting

fixed flag option
This commit is contained in:
AUTOMATIC 2022-08-30 12:55:38 +03:00
parent 587db9c420
commit 54f74d4472
1 changed files with 72 additions and 10 deletions

View File

@ -9,7 +9,7 @@ import torch.nn as nn
import numpy as np
import gradio as gr
from omegaconf import OmegaConf
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin, ImageFilter, ImageOps
from torch import autocast
import mimetypes
import random
@ -158,6 +158,7 @@ class Options:
"samples_save": OptionInfo(True, "Save indiviual samples"),
"samples_format": OptionInfo('png', 'File format for indiviual samples'),
"grid_save": OptionInfo(True, "Save image grids"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"grid_format": OptionInfo('png', 'File format for grids'),
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
@ -957,6 +958,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
if (p.prompt_matrix or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
return_grid = opts.return_grid
if p.prompt_matrix:
grid = image_grid(output_images, p.batch_size, rows=1 << ((len(prompt_matrix_parts)-1)//2))
@ -967,10 +970,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
print("Error creating prompt_matrix text:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
output_images.insert(0, grid)
return_grid = True
else:
grid = image_grid(output_images, p.batch_size)
if return_grid:
output_images.insert(0, grid)
save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
grid_count += 1
@ -1042,7 +1048,7 @@ class Flagging(gr.FlaggingCallback):
os.makedirs("log/images", exist_ok=True)
# those must match the "txt2img" function
prompt, ddim_steps, sampler_name, use_gfpgan, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, code, images, seed, comment = flag_data
prompt, steps, sampler_index, use_gfpgan, prompt_matrix, n_iter, batch_size, cfg_scale, seed, height, width, code, images, seed, comment = flag_data
filenames = []
@ -1067,7 +1073,7 @@ class Flagging(gr.FlaggingCallback):
filenames.append(filename)
writer.writerow([prompt, seed, width, height, cfg_scale, ddim_steps, filenames[0]])
writer.writerow([prompt, seed, width, height, cfg_scale, steps, filenames[0]])
print("Logged:", filenames[0])
@ -1097,27 +1103,64 @@ txt2img_interface = gr.Interface(
flagging_callback=Flagging()
)
def fill(image, mask):
image_mod = Image.new('RGBA', (image.width, image.height))
image_masked = Image.new('RGBa', (image.width, image.height))
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
image_masked = image_masked.convert('RGBa')
for radius, repeats in [(64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
for _ in range(repeats):
image_mod.alpha_composite(blurred)
return image_mod.convert("RGB")
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, **kwargs):
def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
self.resize_mode: int = resize_mode
self.denoising_strength: float = denoising_strength
self.init_latent = None
self.original_mask = mask
self.mask_blur = mask_blur
self.mask = None
self.nmask = None
def init(self):
self.sampler = samplers_for_img2img[self.sampler_index].constructor()
if self.original_mask is not None:
if self.mask_blur > 0:
self.original_mask = self.original_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)).convert('L')
latmask = self.original_mask.convert('RGB').resize((64, 64))
latmask = np.moveaxis(np.array(latmask, dtype=np.float), 2, 0) / 255
latmask = latmask[0]
latmask = np.tile(latmask[None], (4, 1, 1))
self.mask = torch.asarray(1.0 - latmask).to(device).type(sd_model.dtype)
self.nmask = torch.asarray(latmask).to(device).type(sd_model.dtype)
imgs = []
for img in self.init_images:
image = img.convert("RGB")
image = resize_image(self.resize_mode, image, self.width, self.height)
if self.original_mask is not None
image = fill(image, self.original_mask)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
imgs.append(image)
if len(imgs) == 1:
@ -1139,16 +1182,33 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sigmas = self.sampler.model_wrap.get_sigmas(self.steps)
noise = x * sigmas[self.steps - t_enc - 1]
xi = self.init_latent + noise
sigma_sched = sigmas[self.steps - t_enc - 1:]
samples_ddim = self.sampler.func(self.sampler.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.cfg_scale}, disable=False)
#if self.mask is not None:
# xi = xi * self.mask + noise * self.nmask
def mask_cb(v):
v["denoised"][:] = v["denoised"][:] * self.nmask + self.init_latent * self.mask
samples_ddim = self.sampler.func(self.sampler.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.cfg_scale}, disable=False, callback=mask_cb if self.mask is not None else None)
if self.mask is not None:
samples_ddim = samples_ddim * self.nmask + self.init_latent * self.mask
return samples_ddim
def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
def img2img(prompt: str, init_img, init_img_with_mask, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
outpath = opts.outdir or "outputs/img2img-samples"
if init_img_with_mask is not None:
image = init_img_with_mask['image']
mask = init_img_with_mask['mask']
else:
image = init_img
mask = None
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
p = StableDiffusionProcessingImg2Img(
@ -1164,7 +1224,8 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
init_images=[init_img],
init_images=[image],
mask=mask,
resize_mode=resize_mode,
denoising_strength=denoising_strength,
extra_generation_params={"Denoising Strength": denoising_strength}
@ -1262,7 +1323,8 @@ img2img_interface = gr.Interface(
wrap_gradio_call(img2img),
inputs=[
gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil"),
gr.Image(label="Image for img2img", source="upload", interactive=True, type="pil"),
gr.Image(label="Image for inpainting with mask", source="upload", interactive=True, type="pil", tool="sketch"),
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20),
gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=have_gfpgan),